Merge branch 'master' into interface_16x8

Elena Zhelezina 2020-03-03 10:18:40 +00:00 committed by GitHub
commit eaffdc0340
642 changed files with 14462 additions and 6188 deletions


@ -137,6 +137,7 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
# environment variable "TF_MKL_ROOT" every time before build.
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_mkl_dnn_v1_only=true
build:mkl -c opt
# This config option is used to enable MKL-DNN open source library only,
@ -238,7 +239,6 @@ build:c++1z --config=c++17
# Enable using platform specific build settings
build --enable_platform_specific_config
build --config=xla
# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
build:linux --copt=-w


@ -1,5 +1,6 @@
---
name: Feature Request about: Use this template for raising a feature request
name: Feature Request
about: Use this template for raising a feature request
labels: 'type:feature'
---


@ -130,18 +130,20 @@ Build Type | Status
## Resources
* [TensorFlow.org](https://www.tensorflow.org)
* [TensorFlow tutorials](https://www.tensorflow.org/tutorials/)
* [TensorFlow official models](https://github.com/tensorflow/models/tree/master/official)
* [TensorFlow examples](https://github.com/tensorflow/examples)
* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
* [TensorFlow Official Models](https://github.com/tensorflow/models/tree/master/official)
* [TensorFlow Examples](https://github.com/tensorflow/examples)
* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice)
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [TensorFlow blog](https://blog.tensorflow.org)
* [TensorFlow Blog](https://blog.tensorflow.org)
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
* [TensorFlow Twitter](https://twitter.com/tensorflow)
* [TensorFlow YouTube](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
* [TensorFlow roadmap](https://www.tensorflow.org/community/roadmap)
* [TensorFlow white papers](https://www.tensorflow.org/about/bib)
* [TensorBoard visualization toolkit](https://github.com/tensorflow/tensorboard)
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
* [TensorBoard Visualization Toolkit](https://github.com/tensorflow/tensorboard)
Learn more about the
[TensorFlow community](https://www.tensorflow.org/community) and how to


@ -1390,7 +1390,7 @@ def main():
else:
environ_cp['TF_CONFIGURE_IOS'] = '0'
if environ_cp.get('TF_ENABLE_XLA', 1):
if environ_cp.get('TF_ENABLE_XLA', '1') == '1':
write_to_bazelrc('build --config=xla')
set_action_env_var(
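The reason for the configure.py change above: values in environ_cp are strings, so the old check `if environ_cp.get('TF_ENABLE_XLA', 1):` is truthy even when the user sets TF_ENABLE_XLA=0, and --config=xla was written regardless. A minimal Python sketch illustrating the difference (illustration only, not part of the diff):

# Illustration of the truthiness bug fixed above; environ_cp values are strings.
environ_cp = {'TF_ENABLE_XLA': '0'}  # user explicitly disabled XLA

if environ_cp.get('TF_ENABLE_XLA', 1):            # old check: '0' is a non-empty string, still truthy
    print('old check would still write --config=xla')

if environ_cp.get('TF_ENABLE_XLA', '1') == '1':   # new check: only the literal string '1' enables XLA
    print('new check writes --config=xla')
else:
    print('new check leaves XLA disabled')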


@ -1213,10 +1213,10 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
tensorflow::TensorHandle* ret_handle;
if (custom_device == nullptr) {
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, device, context, &ret_handle);
std::move(t), device, device, context, &ret_handle);
} else {
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, custom_device, context, &ret_handle);
std::move(t), custom_device, context, &ret_handle);
}
if (!status->status.ok()) {
return nullptr;


@ -307,14 +307,12 @@ cc_library(
"transforms/runtime_type_verify.cc",
"transforms/split_merged_operands.cc",
"transforms/trim_functions_tf.cc",
"transforms/unroll_batch_matmul.cc",
"transforms/while_loop_outline.cc",
],
hdrs = [
"ir/tfl_ops_interface.h.inc",
"transforms/dilated_conv.h",
"transforms/passes.h",
"transforms/unroll_batch_matmul.h",
],
deps = [
":common",
@ -327,6 +325,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",


@ -534,7 +534,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
let summary = "ArgMin operator";
let description = [{
Returns the index with the smallest value across dimensions of a tensor."
Returns the index with the smallest value across dimensions of a tensor.
a = [1, 10, 26.9, 2.8, 166.32, 62.3]
b = tf.math.argmin(input = a)
c = tf.keras.backend.eval(b)
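For reference, a runnable version of the ArgMin example quoted in the description above (a sketch using the public TF Python API, not part of the diff):

import tensorflow as tf

a = tf.constant([1, 10, 26.9, 2.8, 166.32, 62.3])
b = tf.math.argmin(input=a)  # index of the smallest value
print(int(b))  # -> 0, since a[0] == 1 is the minimum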
@ -3181,19 +3181,19 @@ def TFL_LSTMOp :
Long short-term memory unit (LSTM) recurrent network layer.
The default non-peephole implementation is based on:
http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation,
S. Hochreiter and J. Schmidhuber. 'Long Short-Term Memory'. Neural Computation,
9(8):1735-1780, 1997.
The peephole implementation is based on:
https://research.google.com/pubs/archive/43905.pdf
Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory
recurrent neural network architectures for large scale acoustic modeling.
Hasim Sak, Andrew Senior, and Francoise Beaufays. 'Long short-term memory
recurrent neural network architectures for large scale acoustic modeling.'
INTERSPEECH, 2014.
The coupling of input and forget gate (CIFG) is based on:
http://arxiv.org/pdf/1503.04069.pdf
Greff et al. "LSTM: A Search Space Odyssey"
Greff et al. 'LSTM: A Search Space Odyssey'
The layer normalization is based on:
https://arxiv.org/pdf/1607.06450.pdf
Ba et al. Layer Normalization
Ba et al. 'Layer Normalization'
}];
let arguments = (


@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
@ -197,60 +198,60 @@ Status RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs) {
return Status::OK();
}
} // namespace
Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
// Register any custom OpDefs.
std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
toco_flags.custom_opdefs().end());
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
return RegisterCustomBuiltinOps(extra_tf_opdefs);
}
Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
const GraphDebugInfo& debug_info,
const GraphDef& input,
string* result) {
mlir::MLIRContext context;
GraphImportConfig specs;
mlir::TFL::QuantizationSpecs quant_specs;
// Parse input arrays.
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<double> node_mins;
std::vector<double> node_maxs;
quant_specs.inference_input_type =
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
mlir::TFL::QuantizationSpecs* quant_specs,
std::vector<string>* node_names,
std::vector<string>* node_dtypes,
std::vector<std::vector<int>>* node_shapes,
std::vector<double>* node_mins,
std::vector<double>* node_maxs) {
quant_specs->inference_input_type =
ConvertIODataTypeToDataType(toco_flags.inference_input_type());
tensorflow::DataType inference_type =
ConvertIODataTypeToDataType(toco_flags.inference_type());
// Use non-float flag `inference_input_type` to override the `inference_type`
// because we have to apply quantization to satisfy that.
if (quant_specs.inference_input_type != tensorflow::DT_FLOAT) {
inference_type = quant_specs.inference_input_type;
if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) {
inference_type = quant_specs->inference_input_type;
}
for (auto& flag : model_flags.input_arrays()) {
node_names.push_back(flag.name());
node_names->push_back(flag.name());
// TOCO doesn't required `data_type` to be filled for every input.
// If it's not filled, make it an empty string so the importer will use
// the data type in the NodeDef.
auto toco_data_type = flag.data_type();
if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN) {
node_dtypes.push_back("");
node_dtypes->push_back("");
} else {
node_dtypes.push_back(
node_dtypes->push_back(
DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
}
node_shapes.push_back(std::vector<int>(flag.shape().dims().begin(),
flag.shape().dims().end()));
node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
flag.shape().dims().end()));
// Currently, only UINT8 and INT8 require inputs stats
if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
TF_ASSIGN_OR_RETURN(
auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(),
inference_type));
node_mins.push_back(min_max.first);
node_maxs.push_back(min_max.second);
node_mins->push_back(min_max.first);
node_maxs->push_back(min_max.second);
}
}
TF_RETURN_IF_ERROR(tensorflow::ParseInputArrayInfo(
node_names, node_dtypes, node_shapes, &specs.inputs));
if (mlir::TFL::GetInputNodeQuantSpecs(node_names, node_mins, node_maxs,
inference_type, &quant_specs)) {
if (mlir::TFL::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs,
inference_type, quant_specs)) {
return errors::InvalidArgument("Failed to get input quant spec.");
}
@ -258,49 +259,34 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
// quantization is enabled, `inference_type` and `inference_input_type` are
// not used by MLIR passes.
if (toco_flags.post_training_quantize()) {
quant_specs.weight_quantization = true;
quant_specs->weight_quantization = true;
if (toco_flags.quantize_to_float16()) {
quant_specs.inference_type = tensorflow::DT_HALF;
quant_specs.inference_input_type = tensorflow::DT_HALF;
quant_specs->inference_type = tensorflow::DT_HALF;
quant_specs->inference_input_type = tensorflow::DT_HALF;
} else {
quant_specs.inference_type = tensorflow::DT_QINT8;
quant_specs.inference_input_type = tensorflow::DT_QINT8;
quant_specs->inference_type = tensorflow::DT_QINT8;
quant_specs->inference_input_type = tensorflow::DT_QINT8;
}
}
// Parse output arrays.
std::vector<string> output_arrays(model_flags.output_arrays().begin(),
model_flags.output_arrays().end());
TF_RETURN_IF_ERROR(
tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));
// Other flags.
if (toco_flags.has_default_ranges_min()) {
quant_specs.default_ranges.first = toco_flags.default_ranges_min();
quant_specs->default_ranges.first = toco_flags.default_ranges_min();
}
if (toco_flags.has_default_ranges_max()) {
quant_specs.default_ranges.second = toco_flags.default_ranges_max();
quant_specs->default_ranges.second = toco_flags.default_ranges_max();
}
return ::tensorflow::Status::OK();
}
Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
mlir::OwningModuleRef module,
mlir::TFL::QuantizationSpecs quant_specs,
string* result) {
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
bool emit_custom_ops = toco_flags.allow_custom_ops();
specs.prune_unused_nodes = true;
specs.convert_legacy_fed_inputs = true;
specs.graph_as_function = false;
specs.upgrade_legacy = true;
WarningUnusedFlags(model_flags, toco_flags);
// Register any custom OpDefs.
std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
toco_flags.custom_opdefs().end());
extra_tf_opdefs.push_back(kDetectionPostProcessOp);
extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps(extra_tf_opdefs));
TF_ASSIGN_OR_RETURN(
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
@ -334,4 +320,52 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
return status;
}
} // namespace
Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
const toco::TocoFlags& toco_flags,
const GraphDebugInfo& debug_info,
const GraphDef& input,
string* result) {
mlir::MLIRContext context;
GraphImportConfig specs;
mlir::TFL::QuantizationSpecs quant_specs;
// Parse input arrays.
std::vector<string> node_names;
std::vector<string> node_dtypes;
std::vector<std::vector<int>> node_shapes;
std::vector<double> node_mins;
std::vector<double> node_maxs;
// Populate quantization specs.
TF_RETURN_IF_ERROR(PopulateQuantizationSpecs(
model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes,
&node_shapes, &node_mins, &node_maxs));
TF_RETURN_IF_ERROR(tensorflow::ParseInputArrayInfo(
node_names, node_dtypes, node_shapes, &specs.inputs));
// Parse output arrays.
std::vector<string> output_arrays(model_flags.output_arrays().begin(),
model_flags.output_arrays().end());
TF_RETURN_IF_ERROR(
tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));
specs.prune_unused_nodes = true;
specs.convert_legacy_fed_inputs = true;
specs.graph_as_function = false;
specs.upgrade_legacy = true;
WarningUnusedFlags(model_flags, toco_flags);
// Register all custom ops, including user-specified custom ops.
TF_RETURN_IF_ERROR(RegisterAllCustomOps(toco_flags));
TF_ASSIGN_OR_RETURN(
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
return ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
quant_specs, result);
}
} // namespace tensorflow


@ -180,13 +180,12 @@ func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor
// CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK: [[VAL_19:%.*]] = constant unit
// CHECK: [[VAL_20:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) ( {
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
// CHECK: [[VAL_21:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_22:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_23:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_24:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
// CHECK: return [[VAL_21]], [[VAL_25:%.*]], [[VAL_22]], [[VAL_23]], [[VAL_24]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
// CHECK: return [[VAL_21]], [[VAL_20]], [[VAL_22]], [[VAL_23]], [[VAL_24]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
// CHECK: }
}
@ -221,15 +220,14 @@ func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: te
// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK: [[VAL_21:%.*]] = constant unit
// CHECK: [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( {
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
// CHECK: [[VAL_23:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
// CHECK: [[VAL_24:%.*]] = "tf.Transpose"([[VAL_25:%.*]], [[VAL_23]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
// CHECK: [[VAL_24:%.*]] = "tf.Transpose"([[VAL_22]], [[VAL_23]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_29:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
// CHECK: return [[VAL_26]], [[VAL_24]], [[VAL_27]], [[VAL_28]], [[VAL_29]] : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
// CHECK: [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
// CHECK: return [[VAL_25]], [[VAL_24]], [[VAL_26]], [[VAL_27]], [[VAL_28]] : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
// CHECK: }
}
@ -264,13 +262,12 @@ func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor<?x8x8xf32>,
// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK: [[VAL_21:%.*]] = constant unit
// CHECK: [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( {
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
// CHECK: [[VAL_23:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_24:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
// CHECK: return [[VAL_23]], [[VAL_27:%.*]], [[VAL_24]], [[VAL_25]], [[VAL_26]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
// CHECK: return [[VAL_23]], [[VAL_22]], [[VAL_24]], [[VAL_25]], [[VAL_26]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
// CHECK: }
}
@ -308,16 +305,92 @@ func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf3
// CHECK: [[VAL_21:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_22:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_20]], [[VAL_21]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK: [[VAL_23:%.*]] = constant unit
// CHECK: [[VAL_24:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_16]]#0, [[VAL_16]]#1, [[VAL_16]]#2, [[VAL_16]]#3, [[VAL_19]]#0, [[VAL_19]]#1, [[VAL_19]]#2, [[VAL_19]]#3, [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_22]]#0, [[VAL_22]]#1, [[VAL_22]]#2, [[VAL_22]]#3, [[VAL_23]], [[VAL_23]], [[VAL_1]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) ( {
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
// CHECK: [[VAL_24:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_16]]#0, [[VAL_16]]#1, [[VAL_16]]#2, [[VAL_16]]#3, [[VAL_19]]#0, [[VAL_19]]#1, [[VAL_19]]#2, [[VAL_19]]#3, [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_22]]#0, [[VAL_22]]#1, [[VAL_22]]#2, [[VAL_22]]#3, [[VAL_23]], [[VAL_23]], [[VAL_1]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
// CHECK: [[VAL_25:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
// CHECK: [[VAL_26:%.*]] = "tf.Transpose"([[VAL_27:%.*]], [[VAL_25]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
// CHECK: [[VAL_26:%.*]] = "tf.Transpose"([[VAL_24]], [[VAL_25]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
// CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_29:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_30:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_31:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
// CHECK: return [[VAL_28]], [[VAL_26]], [[VAL_29]], [[VAL_30]], [[VAL_31]] : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
// CHECK: [[VAL_30:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
// CHECK: return [[VAL_27]], [[VAL_26]], [[VAL_28]], [[VAL_29]], [[VAL_30]] : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
// CHECK: }
}
// -----
module {
func @inference_can_fuse(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) {
%0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = f32, value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
%1:5 = "tf.PartitionedCall"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], _output_shapes = ["tfshape$dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$"], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\008\01", device = "", executor_type = "", f = @inference_standard_lstm_time_major_can_fuse} : (tensor<?x8x8xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<8x40xf32>, tensor<10x40xf32>, tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>)
%2 = "tf.Add"(%0, %1#1) : (tensor<f32>, tensor<?x8x10xf32>) -> tensor<?x8x10xf32>
return
}
func @inference_standard_lstm_time_major_can_fuse(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
%3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
%4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}
// CHECK: func @inference_standard_lstm_time_major_can_fuse([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
// CHECK: [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
// CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK: [[VAL_19:%.*]] = constant unit
// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
// CHECK: [[VAL_21:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_22:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_23:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
// CHECK: [[VAL_24:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
// CHECK: return [[VAL_21]], [[VAL_20]], [[VAL_22]], [[VAL_23]], [[VAL_24]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}
// -----
module {
func @inference_cannot_fuse(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) {
%0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = f32, value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
%1:5 = "tf.PartitionedCall"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], _output_shapes = ["tfshape$dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$"], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\008\01", device = "", executor_type = "", f = @inference_standard_lstm_time_major_cannot_fuse} : (tensor<?x8x8xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<8x40xf32>, tensor<10x40xf32>, tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>)
%2 = "tf.Add"(%0, %1#2) : (tensor<f32>, tensor<?x10xf32>) -> tensor<?x10xf32>
return
}
func @inference_standard_lstm_time_major_cannot_fuse(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
%3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
%4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}
// CHECK: func @inference_standard_lstm_time_major_cannot_fuse([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
// CHECK: [[VAL_6:%.*]] = "tf.BatchMatMulV2"([[VAL_0]], [[VAL_3]]) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
// CHECK: [[VAL_7:%.*]] = "tf.Add"([[VAL_6]], [[VAL_5]]) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
// CHECK: [[VAL_8:%.*]] = "tf.BatchMatMulV2"([[VAL_7]], [[VAL_4]]) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
// CHECK: [[VAL_9:%.*]] = "tf.Add"([[VAL_8]], [[VAL_1]]) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
// CHECK: [[VAL_10:%.*]] = "tf.Add"([[VAL_8]], [[VAL_2]]) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
// CHECK: [[VAL_11:%.*]] = "tf.Add"([[VAL_1]], [[VAL_2]]) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
// CHECK: [[VAL_12:%.*]] = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: return [[VAL_11]], [[VAL_10]], [[VAL_11]], [[VAL_11]], [[VAL_12]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
// CHECK: }
}


@ -347,6 +347,16 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
Value element_shape = operands[0];
Type shape_dtype = getElementTypeOrSelf(element_shape.getType());
// If the `element_shape` is a scalar, we know that it's dynamic shape
// and returns an error.
if (auto shaped_type = element_shape.getType().dyn_cast<ShapedType>()) {
if (shaped_type.getRank() == 0) {
op.emitError(
"requires element_shape to be 1D tensor during TF Lite "
"transformation pass");
return ConversionPattern::matchFailure();
}
}
DenseIntElementsAttr dense_elem_attr;
if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) {


@ -61,7 +61,7 @@ std::unique_ptr<OpPassBase<ModuleOp>> CreateTrimFunctionsPass(
// Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions
// pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareCompositeFunctionsPass();
std::unique_ptr<OpPassBase<ModuleOp>> CreatePrepareCompositeFunctionsPass();
// Creates an instance of the TensorFlow Lite dialect ExtractOphint pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateExtractOphintPass();


@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Analysis/CallInterfaces.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
@ -35,10 +36,12 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
// NOLINTNEXTLINE
@ -91,14 +94,14 @@ class ConvertEmbeddedLookupFunc {
// body with the corresponding fused TFLite op. The replacement need not always
// be a fused op, though that is the primary use case.
class PrepareCompositeFunctionsPass
: public FunctionPass<PrepareCompositeFunctionsPass> {
: public ModulePass<PrepareCompositeFunctionsPass> {
public:
explicit PrepareCompositeFunctionsPass() {}
private:
void ConvertTFImplements(FuncOp func, StringAttr attr);
void ConvertTFAPIImplements(FuncOp func, StringAttr attr);
void runOnFunction() override;
void ConvertTFAPIImplements(FuncOp func, StringAttr attr, ModuleOp module);
void runOnModule() override;
};
void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
@ -131,14 +134,54 @@ void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
}
}
LogicalResult CheckOutputConsumer(
Operation* call_op, int expected_num_outputs,
llvm::DenseSet<int> expected_consumer_indices) {
if (call_op->getNumResults() != expected_num_outputs) return failure();
for (int i = 0; i < expected_num_outputs; ++i) {
auto it = expected_consumer_indices.find(i);
if (it != expected_consumer_indices.end()) {
// Expected consumer.
if (call_op->getResult(i).use_empty()) return failure();
} else {
// Unexpected consumer.
if (!call_op->getResult(i).use_empty()) return failure();
}
}
return success();
}
LogicalResult CheckFusableKerasLstm(FuncOp lstm_func, ModuleOp module) {
bool check_failed = false;
for (auto func : module.getOps<FuncOp>()) {
func.walk([&](Operation* op) {
auto call_op = dyn_cast_or_null<CallOpInterface>(op);
if (call_op && op->getAttrOfType<SymbolRefAttr>("f").getRootReference() ==
lstm_func.getName()) {
// Keras LSTM have 5 outputs.
// We should make sure only the second output is consumed.
if (failed(CheckOutputConsumer(call_op, 5, {1}))) check_failed = true;
}
});
}
if (check_failed) return failure();
return success();
}
void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
StringAttr attr) {
StringAttr attr,
ModuleOp module) {
// Keras lstm tf.api_implements usually has attribute like "lstm_abcde91...".
// TODO(b/147436982): we need to make sure that only the
// outputs(full sequence) is used, not the last_output, not the new_states.
// We will discard everything except the outputs.
// And the outputs is in the shape of [batch, time, units].
if (attr.getValue().startswith("lstm_")) {
// Check if the keras lstm can be fused, if not, we just don't do anything.
if (failed(CheckFusableKerasLstm(func, module))) return;
func.eraseBody();
func.addEntryBlock();
@ -148,26 +191,29 @@ void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
}
}
void PrepareCompositeFunctionsPass::runOnFunction() {
auto func = getFunction();
// We have two kinds of implements:
// 1) tf._implements.
// 2) tf.api_implements.
// We need to handle them separately.
auto tf_implements_attr = func.getAttrOfType<StringAttr>(kTFImplements);
if (tf_implements_attr) {
ConvertTFImplements(func, tf_implements_attr);
} else {
void PrepareCompositeFunctionsPass::runOnModule() {
auto module = getModule();
for (auto func : module.getOps<FuncOp>()) {
// We have two kinds of implements:
// 1) tf._implements.
// 2) tf.api_implements.
// We need to handle them separately.
auto tf_implements_attr = func.getAttrOfType<StringAttr>(kTFImplements);
if (tf_implements_attr) {
ConvertTFImplements(func, tf_implements_attr);
}
auto tf_api_implements_attr =
func.getAttrOfType<StringAttr>(kTFAPIImplements);
if (!tf_api_implements_attr) return;
// TODO(b/147536816): Keras lstm should set up the correct attributes.
ConvertTFAPIImplements(func, tf_api_implements_attr);
if (tf_api_implements_attr) {
// TODO(b/147536816): Keras lstm should set up the correct attributes.
ConvertTFAPIImplements(func, tf_api_implements_attr, module);
}
}
}
} // namespace
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareCompositeFunctionsPass() {
std::unique_ptr<OpPassBase<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
return std::make_unique<PrepareCompositeFunctionsPass>();
}
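To make the fusion condition concrete: CheckOutputConsumer above requires that, of the five Keras LSTM outputs, only output #1 (the full [batch, time, units] sequence) has uses. A hedged Keras-level sketch of the two cases (an illustration, not taken from the diff; whether fusion actually fires also depends on the rest of the converter pipeline):

import tensorflow as tf

inputs = tf.keras.Input(shape=(8, 8))

# Fusable pattern: only the full-sequence output is consumed,
# matching CheckOutputConsumer(call_op, 5, {1}) in the pass above.
seq = tf.keras.layers.LSTM(10, return_sequences=True)(inputs)
fusable = tf.keras.Model(inputs, seq)

# Non-fusable pattern: the final states are also consumed, so other outputs
# of the generated lstm_* function have uses and the check fails.
seq2, h, c = tf.keras.layers.LSTM(10, return_sequences=True, return_state=True)(inputs)
not_fusable = tf.keras.Model(inputs, [seq2, h, c])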


@ -53,10 +53,10 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
#define DEBUG_TYPE "tf-tfl-legalization"
@ -652,8 +652,8 @@ void PrepareTFPass::runOnFunction() {
patterns.clear();
TFL::populateWithGenerated(ctx, &patterns);
if (unfold_batch_matmul_) {
patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(ctx);
patterns.insert<TF::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(ctx);
}
patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative,
ConvertTFStridedSlice>(ctx);


@ -692,7 +692,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
Value none = builder->create<mlir::ConstantOp>(
func_op.getLoc(), builder->getNoneType(), builder->getUnitAttr());
auto lstm = builder->create<mlir::TFL::LSTMOp>(
auto lstm = builder->create<mlir::TFL::UnidirectionalSequenceLSTMOp>(
func_op.getLoc(), result_type, /*input=*/input,
/*input_to_input_weights=*/weights_array->getResult(0),
/*input_to_forget_weights=*/weights_array->getResult(1),
@ -718,7 +718,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
/*cell_layer_norm_coefficients=*/none,
/*output_layer_norm_coefficients=*/none, builder->getStringAttr("TANH"),
builder->getF32FloatAttr(10.0), builder->getF32FloatAttr(0.0),
builder->getStringAttr("FULL"));
builder->getBoolAttr(true));
auto final_output = lstm.getResult();
if (!time_majored) {


@ -198,6 +198,7 @@ cc_library(
"ir/tf_verifiers.h",
"transforms/bridge.h",
"transforms/passes.h",
"transforms/unroll_batch_matmul.h",
"@llvm-project//mlir:include/mlir/Analysis/CallInterfaces.h",
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
],
@ -297,6 +298,7 @@ cc_library(
"transforms/shape_inference.cc",
"transforms/shape_inference_pass.cc",
"transforms/sink_constant.cc",
"transforms/stack_ops_decomposition.cc",
"transforms/test_side_effect_analysis.cc",
"transforms/tf_device_assignment.cc",
"transforms/tpu_cluster_formation.cc",
@ -304,7 +306,9 @@ cc_library(
"transforms/tpu_dynamic_padding_mapper.cc",
"transforms/tpu_merge_variables_with_execute.cc",
"transforms/tpu_rewrite_pass.cc",
"transforms/tpu_sharding_identification_pass.cc",
"transforms/tpu_variable_runtime_reformatting.cc",
"transforms/unroll_batch_matmul.cc",
"translate/breakup-islands.cc",
"translate/control_to_executor_dialect.cc",
"translate/executor_to_control_dialect.cc",
@ -314,6 +318,7 @@ cc_library(
"transforms/bridge.h",
"transforms/passes.h",
"transforms/shape_inference.h",
"transforms/unroll_batch_matmul.h",
],
includes = ["include"],
deps = [
@ -335,6 +340,7 @@ cc_library(
"//tensorflow/compiler/mlir/lite:validators",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla:xla_proto_cc",
"//tensorflow/compiler/xla/client:sharding_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@ -342,12 +348,14 @@ cc_library(
"//tensorflow/core/platform:random",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
@ -396,6 +404,7 @@ cc_library(
deps = [
":convert_tensor",
":convert_type",
":error_util",
":export_tf_dialect_op",
":export_utils",
":mangling_util",
@ -972,6 +981,18 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "mlir_local_var_op",
srcs = ["ops/mlir_local_var_op.cc"],
visibility = [
"//visibility:public",
],
deps = [
"//tensorflow/core:framework",
],
alwayslink = 1,
)
tf_gen_op_wrapper_py(
name = "gen_mlir_passthrough_op_py",
out = "gen_mlir_passthrough_op.py",


@ -99,6 +99,24 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
addInterfaces<TFInlinerInterface>();
}
//===----------------------------------------------------------------------===//
// tf_device.launch
//===----------------------------------------------------------------------===//
// Checks if a tf_device.launch wraps a single operation and the single
// operation results are perfectly forwarded to the launch return.
bool LaunchOp::WrapsSingleOp() {
auto body = GetBody().without_terminator();
if (!has_single_element(body)) return false;
Operation& wrapped_op = *body.begin();
Operation* terminator = GetBody().getTerminator();
return wrapped_op.getNumResults() == terminator->getNumOperands() &&
std::equal(wrapped_op.getResults().begin(),
wrapped_op.getResults().end(),
terminator->getOperands().begin());
}
//===----------------------------------------------------------------------===//
// tf_device.return
//===----------------------------------------------------------------------===//


@ -66,6 +66,7 @@ def TfDevice_LaunchOp : TfDevice_Op<"launch",
let extraClassDeclaration = [{
Block &GetBody() { return getOperation()->getRegion(0).front(); }
StringRef getDevice() { return device(); }
bool WrapsSingleOp();
}];
let builders = [
@ -129,26 +130,27 @@ def TfDevice_ReplicateOp :
let summary = "Wraps an N-way replicated computation.";
let description = [{
The region held by this operation represents a computation that is
replicated across multiple devices. The number of replications is based
on the `n` attribute. Explicit devices can be populated in the `devices`
attribute, and it must be a map of device alias and a list of explicit
device names or aliased device names from outer scope. Device name list
specifies devices on which replicated ops inside tf_device.replicate will
be executed. If tf_device.parallel_execute is inside replicate op region,
region inside replicate may represent computations across larger set of
devices. In that case, device alias can be used to specify device assignment
and replication of each concurrent execution (i.e. region) defined by
parallel_execute op. The size of the device name list must match `n`.
Within a replica, the execution semantics follow standard sequential behavior.
Ops in the replicate with no device assigned will have its device set to the
associated replicate device from `devices`. Operands are replicated inputs:
each group of `n` inputs corresponds to an input for a single individual
replica and is mapped to a single region argument. Inside one group the
operands are matching in order the `devices` attribute. Each replicated
input must have compatible shapes and types. Operands not replicated can
be implicitly captured by ops in the region. Results are replicated each
from the regions terminator.
The region held by this operation represents a computation that is replicated
across multiple devices. The number of replications is based on the `n`
attribute. Explicit devices can be populated in the `devices` attribute, and it
must be a mapping of device alias to list of explicit or aliased device names
from the outer scope. The device name map specifies devices on which replicated
ops inside tf_device.replicate will be executed. A tf_device.parallel_execute
inside the tf_device.replicate op region may be used to represent computations
across a larger set of devices. In that case, the device alias can be used to
specify device assignment and replication of each concurrent execution
(i.e. region) defined by tf_device.parallel_execute op. The size of each value
list in the device name map must match `n`. Within a replica, the execution
semantics follow standard sequential behavior. Ops in the tf_device.replicate
wrapped with a tf_device.launch will have its device set to the associated
replicated device from `devices` if the tf_device.launch refers to an aliased
device name. Otherwise the device already set in tf_device.launch is used
instead. Operands are replicated inputs: each group of `n` inputs corresponds to
an input for a single individual replica and is mapped to a single region
argument. Inside one group the operands are matching in order the `devices`
attribute. Each replicated input must have compatible shapes and types. Operands
not replicated can be implicitly captured by ops in the region. Results are
replicated each from the regions terminator.
For example:
```
@ -156,22 +158,45 @@ For example:
%1 = "tf.opB"() : () -> tensor<i32>
%2 = "tf.opC"() : () -> tensor<f32>
%3 = "tf.opD"() : () -> tensor<f32>
%4 = "tf.opE"() : () -> tensor<i1>
%output:4 = tf_device.replicate([%0, %1] as %input_0:tensor<i32>,
[%2, %3] as %input_1:tensor<f32>)
%4 = "tf.opE"() : () -> tensor<!tf.resource>
%5 = "tf.opF"() : () -> tensor<!tf.resource>
%6 = "tf.opG"() : () -> tensor<!tf.string>
%7 = "tf.opH"() : () -> tensor<!tf.string>
%8 = "tf.opI"() : () -> tensor<i1>
%output:8 = tf_device.replicate([%0, %1] as %input_0:tensor<i32>,
[%2, %3] as %input_1:tensor<f32>,
[%4, %5] as %input_2:tensor<!tf.resource>
[%6, %7] as %input_3:tensor<!tf.string>)
{n = 2 : i32,
devices = { DEVICE_ALIAS = ["/DEVICE:0", "/DEVICE:1"]}} {
// Inside the region, %input_0 corresponds to %0 for "/DEVICE:0" and %1 for
// "/DEVICE:1", and %input_1 corresponds to %2 for "/DEVICE:0" and %3 for
// "/DEVICE:1".
%f = "tf.opF"(%input_0, %4) : (tensor<i32>, tensor<i1>) -> tensor<i32>
%g = "tf.opG"(%input_1, %4) : (tensor<f32>, tensor<i1>) -> tensor<f32>
tf_device.return %f, %g : tensor<i32>, tensor<f32>
devices = {DEVICE_ALIAS_0 = ["/DEVICE:0", "/DEVICE:1"],
DEVICE_ALIAS_1 = ["/DEVICE:2", "/DEVICE:3"]}} {
// Inside the region, %0, %2, %4, and %6 corresponds to
// "/DEVICE:0"/"/DEVICE:2" and %1, %3, %5, and %7 corresponds to
// "/DEVICE:1"/"/DEVICE:3", depending on which device alias is used.
%j = "tf_device.launch"() ( {
%9 = "tf.opJ"(%input_0, %6) : (tensor<i32>, tensor<i1>) -> tensor<i32>
tf_device.return %9 : tensor<i32>
}) {device = "DEVICE_ALIAS_0"} : () -> tensor<i32>
%k = "tf_device.launch"() ( {
%10 = "tf.opK"(%input_1, %6) : (tensor<f32>, tensor<i1>) -> tensor<f32>
tf_device.return %10 : tensor<f32>
}) {device = "DEVICE_ALIAS_1"} : () -> tensor<f32>
%l = "tf_device.launch"() ( {
%11 = "tf.opL"(%input_2, %6) : (tensor<!tf.resource>, tensor<i1>)
-> tensor<!tf.resource>
tf_device.return %11 : tensor<!tf.resource>
}) {device = "/DEVICE:4"} : () -> tensor<f32>
%m = "tf.opM"(%input_3, %6) : (tensor<!tf.string>, tensor<i1>)
-> tensor<!tf.string>
tf_device.return %j, %k, %l, %m :
tensor<i32>, tensor<f32>, tensor<!tf.resource>, tensor<!tf.string>
}
// %output#0 corresponds to %f returned from "/DEVICE:0"
// %output#1 corresponds to %f returned from "/DEVICE:1"
// %output#2 corresponds to %g returned from "/DEVICE:0"
// %output#3 corresponds to %g returned from "/DEVICE:1"
// %output#0 corresponds to %j returned from "/DEVICE:0"
// %output#1 corresponds to %j returned from "/DEVICE:1"
// %output#2 corresponds to %k returned from "/DEVICE:2"
// %output#3 corresponds to %k returned from "/DEVICE:3"
// %output#4, %output#5 corresponds to %l and will be returned from "/DEVICE:4"
// %output#6, %output#7 corresponds to %m and will have no device set
```
}];
@ -194,13 +219,11 @@ For example:
let builders = [
OpBuilder<"Builder* builder, OperationState& state, int n, "
"const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& devices, "
"llvm::ArrayRef<std::pair<llvm::ArrayRef<Value>, Type>>"
" replicated_inputs, "
"llvm::ArrayRef<std::pair<llvm::ArrayRef<Value>, Type>> replicated_inputs, "
"llvm::ArrayRef<Type> replica_output_types">,
OpBuilder<"Builder* builder, OperationState& state, int n, "
"const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& devices, "
"llvm::ArrayRef<std::pair<Operation::operand_range, Type>>"
" replicated_inputs, "
"llvm::ArrayRef<std::pair<Operation::operand_range, Type>> replicated_inputs, "
"Operation::result_type_range replica_output_types">
];

View File

@ -3926,6 +3926,21 @@ pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2]
TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>;
}
def TF_MlirLocalVarOp : TF_Op<"MlirLocalVarOp", []> {
let summary = "Creates a handle to a in-scope variable.";
let description = [{
Used by internal passes for temporary representation of local state, which will
eventually be removed.
}];
let arguments = (ins);
let results = (outs
TF_ResourceTensor:$resource
);
}
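For orientation, the stack-decomposition tests added later in this change pair this op with `tf.AssignVariableOp` and `tf.ReadVariableOp`; a minimal sketch, with an assumed element type of `tensor<10xf32>` and a hypothetical `%init` value:
```
// Create a local variable handle, initialize it, then read the value back.
// %init is a hypothetical tensor<10xf32> value defined elsewhere.
%buffer = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<10xf32>>>
"tf.AssignVariableOp"(%buffer, %init) : (tensor<!tf.resource<tensor<10xf32>>>, tensor<10xf32>) -> ()
%val = "tf.ReadVariableOp"(%buffer) : (tensor<!tf.resource<tensor<10xf32>>>) -> tensor<10xf32>
```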
def TF_MlirPassthroughOp : TF_Op<"MlirPassthroughOp", [NoSideEffect]> {
let summary = [{
Wraps an arbitrary MLIR computation expressed as a module with a main() function.
@ -6418,6 +6433,74 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_StackCloseV2Op : TF_Op<"StackCloseV2", []> {
let summary = "Delete the stack from its resource container.";
let description = [{
}];
let arguments = (ins
TF_ResourceTensor:$handle
);
let results = (outs);
}
def TF_StackPopV2Op : TF_Op<"StackPopV2", []> {
let summary = "Pop the element at the top of the stack.";
let description = [{
}];
let arguments = (ins
TF_ResourceTensor:$handle
);
let results = (outs
TF_Tensor:$elem
);
TF_DerivedResultTypeAttr elem_type = TF_DerivedResultTypeAttr<0>;
}
def TF_StackPushV2Op : TF_Op<"StackPushV2", []> {
let summary = "Push an element onto the stack.";
let description = [{
}];
let arguments = (ins
TF_ResourceTensor:$handle,
TF_Tensor:$elem,
DefaultValuedAttr<BoolAttr, "false">:$swap_memory
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}
def TF_StackV2Op : TF_Op<"StackV2", []> {
let summary = "A stack that produces elements in first-in last-out order.";
let description = [{
}];
let arguments = (ins
I32Tensor:$max_size,
TypeAttr:$elem_type,
StrAttr:$stack_name
);
let results = (outs
TF_ResourceTensor:$handle
);
}
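Together, the four stack ops above model a create/push/pop/close lifecycle; a minimal sketch mirroring the decomposition tests below (`%elem` is an assumed `tensor<f32>` value):
```
// Create a stack with capacity 10, push one element, pop it, then close the stack.
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
%push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
%pop = "tf.StackPopV2"(%stack) : (tensor<!tf.resource>) -> tensor<f32>
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
```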
def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Stops gradient computation.";
@ -6701,6 +6784,24 @@ occurred during compilation.
);
}
def TF_TPUCompileSucceededAssertOp : TF_Op<"TPUCompileSucceededAssert", []> {
let summary = [{
Asserts that compilation succeeded.
}];
let description = [{
This op produces no output and closes the device during failure to ensure all
pending device interactions fail.
'compilation_status' is a serialized CompilationResultProto.
}];
let arguments = (ins
TF_StrTensor:$compilation_status
);
let results = (outs);
}
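In the updated tests below, this assert is wrapped in a `tf_device.launch` pinned to the compiling host; a condensed sketch, where `%compile#0` stands for the compilation status produced by `_TPUCompileMlir`:
```
"tf_device.launch"() ( {
  // Fails all pending device interactions if compilation did not succeed.
  "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
  tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
```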
def TF_TPUCopyWithLayoutOp : TF_Op<"TPUCopyWithLayout", [NoSideEffect]> {
let summary = "Op that copies host tensor to device with specified layout.";
@ -6919,6 +7020,38 @@ is the corresponding input gradient.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_TensorListConcatV2Op : TF_Op<"TensorListConcatV2", [NoSideEffect]> {
let summary = "Concats all tensors in the list along the 0th dimension.";
let description = [{
Requires that all tensors have the same shape except the first dimension.
input_handle: The input list.
element_shape: The shape of the uninitialized elements in the list. If the first
dimension is not -1, it is assumed that all list elements have the same
leading dim.
leading_dims: The list of leading dims of uninitialized list elements. Used if
the leading dim of input_handle.element_shape or the element_shape input arg
is not already set.
tensor: The concatenated result.
lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient.
}];
let arguments = (ins
TF_VariantTensor:$input_handle,
TF_I32OrI64Tensor:$element_shape,
I64Tensor:$leading_dims
);
let results = (outs
TF_Tensor:$tensor,
I64Tensor:$lengths
);
TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<1>;
TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<0>;
}
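A minimal, illustrative use of the signature above (the unranked element type, shapes, and SSA names are assumptions, not taken from this change):
```
// Concatenate all list elements along dimension 0; %lengths records each element's leading size.
%tensor, %lengths = "tf.TensorListConcatV2"(%list, %element_shape, %leading_dims)
  : (tensor<!tf.variant>, tensor<1xi32>, tensor<0xi64>) -> (tensor<*xf32>, tensor<?xi64>)
```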
def TF_TensorListFromTensorOp : TF_Op<"TensorListFromTensor", [NoSideEffect]> {
let summary = [{
Creates a TensorList which, when stacked, has the value of `tensor`.
@ -6980,6 +7113,33 @@ length: the number of tensors in the list
);
}
def TF_TensorListPopBackOp : TF_Op<"TensorListPopBack", [NoSideEffect]> {
let summary = [{
Returns the last element of the input list as well as a list with all but that element.
}];
let description = [{
Fails if the list is empty.
input_handle: the input list
tensor: the withdrawn last element of the list
element_dtype: the type of elements in the list
element_shape: the shape of the output tensor
}];
let arguments = (ins
TF_VariantTensor:$input_handle,
I32Tensor:$element_shape
);
let results = (outs
TF_VariantTensor:$output_handle,
TF_Tensor:$tensor
);
TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<1>;
}
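A minimal, illustrative use of the signature above (element type and shapes are assumptions):
```
// Pop the last element; returns the shortened list plus the removed tensor.
%new_list, %elem = "tf.TensorListPopBack"(%list, %element_shape)
  : (tensor<!tf.variant>, tensor<1xi32>) -> (tensor<!tf.variant>, tensor<2xf32>)
```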
def TF_TensorListPushBackOp : TF_Op<"TensorListPushBack", [NoSideEffect]> {
let summary = [{
Returns a list which has the passed-in `Tensor` as last element and the other elements of the given list in `input_handle`.
@ -7823,3 +7983,38 @@ def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF__TPUCompileMlirOp : TF_Op<"_TPUCompileMlir", []> {
let summary = [{
Compiles computations for execution on one or more TPU devices.
}];
let description = [{
For the internal use of the distributed TPU compiler. Note that currently only a
single TPU device is supported.
'mlir_module' is a serialized MLIR module with a `main` function that contains the
target computation.
'dynamic_shapes' contains dynamic shapes of arguments whose shapes were not
known statically at TPUReplication rewrite time.
'metadata' is a serialized TPUCompileMetadataProto describing
the shapes and types of the inputs to the computation, as well as a mapping onto
the TPU pod topology.
'program' output is a string key that is passed to the _TPUExecute op and
used to look up the program in the compilation cache.
}];
let arguments = (ins
Variadic<I64Tensor>:$dynamic_shapes,
StrAttr:$mlir_module,
StrAttr:$metadata
);
let results = (outs
TF_StrTensor:$compilation_status,
TF_StrTensor:$program
);
TF_DerivedOperandSizeAttr NumDynamicShapes = TF_DerivedOperandSizeAttr<0>;
}
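The TPU rewrite tests below wrap this op in a `tf_device.launch` on the host and feed its `program` output to `tf.TPUExecute`; a condensed sketch with the metadata string elided:
```
%compile:2 = "tf_device.launch"() ( {
  %0:2 = "tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "...", mlir_module = "..."}
    : () -> (tensor<!tf.string>, tensor<!tf.string>)
  tf_device.return %0#0, %0#1 : tensor<!tf.string>, tensor<!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
// %compile#0 is the compilation status; %compile#1 is the program key consumed by tf.TPUExecute.
```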

View File

@ -1,4 +1,4 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef CUDA_CUDA_CONFIG_H_
#define CUDA_CUDA_CONFIG_H_
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#define TF_CUDA_CAPABILITIES CudaVersion("3.0"), CudaVersion("6.0")
namespace tensorflow {
#define TF_CUDA_VERSION "10.1"
#define TF_CUDA_LIB_VERSION "10"
#define TF_CUDNN_VERSION "7"
REGISTER_OP("MlirLocalVarOp")
.Output("resource: resource")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"(Creates a handle to a in-scope variable.
Used by internal passes for temporary representation of local state, which will
be eventually removed.)");
#define TF_CUDA_TOOLKIT_PATH "/usr/local/cuda-10.1"
#endif // CUDA_CUDA_CONFIG_H_
} // namespace tensorflow

View File

@ -3,7 +3,10 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
licenses(["notice"])
glob_lit_tests(
data = [":test_utilities"],
data = [
":debug_info_files",
":test_utilities",
],
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["pbtxt"],
)
@ -18,3 +21,13 @@ filegroup(
"@llvm-project//llvm:not",
],
)
# Bundle together all the debug info files that are used by the tests.
filegroup(
name = "debug_info_files",
srcs = glob(
[
"**/*.debug",
],
),
)

View File

@ -0,0 +1,62 @@
# RUN: not tf-mlir-translate -graphdef-to-mlir -tf-input-arrays=x,y -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=2:3 -tf-output-arrays=x_y_sum %s --tf-debug-info=%s.debug -o - 2>&1 | FileCheck %s
# Checks that source debug information is used in the output error message.
# CHECK: Graph import failed: Invalid argument: Dimensions must be equal
# CHECK: math_ops.add(x, y, name='x_y_sum')
# CHECK: build_graph(out_dir)
node: {
name: "x"
op: "Placeholder"
attr: {
key: "shape"
value: {
shape: {
dim: {
size: -1
}
}
}
}
attr: {
key: "dtype"
value: {
type: DT_INT32
}
}
}
node: {
name: "y"
op: "Placeholder"
attr: {
key: "shape"
value: {
shape: {
dim: {
size: -1
}
}
}
}
attr: {
key: "dtype"
value: {
type: DT_INT32
}
}
}
node: {
name: "x_y_sum"
op: "Add"
input: "x"
input: "y"
attr: {
key: "T"
value: {
type: DT_INT32
}
}
}
versions: {
producer: 321
}

View File

@ -0,0 +1,28 @@
files : [ "org_tensorflow/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/error-message-with-source-info.pbtxt.fake_py.debug"]
traces: {
key : "x@"
value: {
file_line_cols: {
line : 1
}
}
}
traces: {
key : "x_y_sum@"
value: {
file_line_cols: {
line : 3
}
file_line_cols: {
line : 4
}
}
}
traces: {
key : "y@"
value: {
file_line_cols: {
line : 2
}
}
}

View File

@ -0,0 +1,4 @@
x = value
y = value
math_ops.add(x, y, name='x_y_sum')
build_graph(out_dir)

View File

@ -25,13 +25,16 @@ func @controls_per_replica() {
// CHECK: %[[ISLAND_2:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_0]], %[[ISLAND_1]])
// Tests devices are not set if no devices were defined in replicate.
// Tests devices are not remapped if no devices were defined in replicate.
// CHECK-LABEL: func @no_devices
func @no_devices() {
tf_executor.graph {
%0 = tf_executor.island {
tf_device.replicate {n = 2 : i32} {
"tf.opA"() : () -> ()
"tf_device.launch"() ( {
"tf.opA"() : () -> ()
tf_device.return
}) {device = "CORE_0"} : () -> ()
tf_device.return
}
tf_executor.yield
@ -41,19 +44,22 @@ func @no_devices() {
return
}
// CHECK: "tf.opA"
// CHECK-NOT: device
// CHECK: "tf.opA"
// CHECK-NOT: device
// CHECK: "tf.opA"
// CHECK: device = "CORE_0"
// CHECK: "tf.opA"
// CHECK: device = "CORE_0"
// Tests devices are not set if op already has a device assigned.
// Tests devices are not remapped if device is not in replicate devices.
// CHECK-LABEL: func @no_override_device
func @no_override_device() {
tf_executor.graph {
%0 = tf_executor.island {
tf_device.replicate {n = 2 : i32, devices = {CORE_0 = ["/CPU:0", "/GPU:1"]}} {
"tf.opA"() {device = "/TPU:2"} : () -> ()
"tf_device.launch"() ( {
"tf.opA"() : () -> ()
tf_device.return
}) {device = "/TPU:2"} : () -> ()
tf_device.return
}
tf_executor.yield
@ -63,20 +69,22 @@ func @no_override_device() {
return
}
// CHECK: "tf.opA"
// CHECK-SAME: device = "/TPU:2"
// CHECK: "tf.opA"
// CHECK-SAME: device = "/TPU:2"
// CHECK: "tf.opA"
// CHECK: device = "/TPU:2"
// CHECK: "tf.opA"
// CHECK: device = "/TPU:2"
// Tests devices are not set if op is not of the TF dialect.
// CHECK-LABEL: func @no_device_non_tf_ops
func @no_device_non_tf_ops() {
// Tests devices are remapped if device is in replicate devices.
// CHECK-LABEL: func @remap_device
func @remap_device() {
tf_executor.graph {
%0 = tf_executor.island {
tf_device.replicate {n = 2 : i32, devices = {CORE_0 = ["/CPU:0", "/GPU:1"]}} {
"test.opA"() {device = "/TPU:2"} : () -> ()
"test.opB"() : () -> ()
"tf_device.launch"() ( {
"tf.opA"() : () -> ()
tf_device.return
}) {device = "CORE_0"} : () -> ()
tf_device.return
}
tf_executor.yield
@ -86,14 +94,10 @@ func @no_device_non_tf_ops() {
return
}
// CHECK: "test.opA"
// CHECK-SAME: device = "/TPU:2"
// CHECK: "test.opB"
// CHECK-NOT: device
// CHECK: "test.opA"
// CHECK-SAME: device = "/TPU:2"
// CHECK: "test.opB"
// CHECK-NOT: device
// CHECK: "tf.opA"
// CHECK: device = "/CPU:0"
// CHECK: "tf.opA"
// CHECK: device = "/GPU:1"
// Tests unused per replica islands are added as a control dependency to the
@ -104,7 +108,7 @@ func @unused_replica_control(%arg0: tensor<i1>, %arg1: tensor<i1>) {
%0 = tf_executor.graph {
%1 = tf_executor.ControlTrigger {}
%2:2 = tf_executor.island(%1) {
%3:4 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>) {n = 2 : i32, devices = {CORE_0 = ["/CPU:0", "/GPU:1"]}} {
%3:4 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>) {n = 2 : i32} {
%4 = "tf.opA"(%ri) : (tensor<i1>) -> tensor<i1>
%5 = "tf.opB"(%4) : (tensor<i1>) -> tensor<i1>
tf_device.return %4, %5 : tensor<i1>, tensor<i1>
@ -119,15 +123,11 @@ func @unused_replica_control(%arg0: tensor<i1>, %arg1: tensor<i1>) {
// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger
// CHECK: %[[ISLAND_0:[a-z_0-9]*]]:2, %{{.*}} = tf_executor.island(%[[CT]])
// CHECK: %[[OP_A_0:[0-9]*]] = "tf.opA"(%[[ARG_0]])
// CHECK-SAME: device = "/CPU:0"
// CHECK: %[[OP_B_0:[0-9]*]] = "tf.opB"(%[[OP_A_0]])
// CHECK-SAME: device = "/CPU:0"
// CHECK: tf_executor.yield %[[OP_A_0]], %[[OP_B_0]]
// CHECK: %[[ISLAND_1:[a-z_0-9]*]]:2, %[[ISLAND_1_control:[a-z_0-9]*]] = tf_executor.island(%[[CT]])
// CHECK: %[[OP_A_1:[0-9]*]] = "tf.opA"(%[[ARG_1]])
// CHECK-SAME: device = "/GPU:1"
// CHECK: %[[OP_B_1:[0-9]*]] = "tf.opB"(%[[OP_A_1]])
// CHECK-SAME: device = "/GPU:1"
// CHECK: tf_executor.yield %[[OP_A_1]], %[[OP_B_1]]
// CHECK: %[[ISLAND_2:.*]], %[[ISLAND_2_control:.*]] = tf_executor.island(%[[ISLAND_1_control]])
// CHECK: tf_executor.yield %[[ISLAND_0]]#0

View File

@ -343,7 +343,7 @@ func @launch_with_loop() -> () {
}
func @while_body(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>) {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
// expected-error @+1 {{Resource used in while loop is only supported when the resource input and output alias each other in the loop body}}
// expected-error @+1 {{resource used in while loop is only supported when the resource input and output alias each other in the loop body}}
return %0 : tensor<*x!tf.resource<tensor<f32>>>
}
func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
@ -367,7 +367,7 @@ func @launch_with_loop() -> () {
return
}
func @while_body(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>) {
// expected-error @+1 {{Found unsupported operations on resource.}}
// expected-error @+1 {{found unsupported operations on resource.}}
"tf._UnknownOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> ()
return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
}
@ -399,7 +399,7 @@ func @while_body(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.re
func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
%read = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
%constant = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
// expected-error @+1 {{Found resource write in loop condition.}}
// expected-error @+1 {{found resource write in loop condition.}}
"tf.AssignVariableOp"(%arg0, %constant) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
return %read : tensor<f32>
}
@ -524,7 +524,7 @@ func @launch_with_if(%arg0: tensor<i1>) -> tensor<4xf32> {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
%2 = "tf_device.launch"() ( {
// expected-error @+1 {{Unsupported tf.IfOp output: resource does not alias a single input.}}
// expected-error @+1 {{unsupported tf.IfOp output: resource does not alias a single input.}}
%3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
output_shapes = ["tfshape$"], is_stateless = false}
: (tensor<i1>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>)
@ -650,7 +650,7 @@ func @launch_with_stateful_partitioned_call() -> () {
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
return
}
// expected-error @+1 {{Unsupported function call: resource return value does not alias an input.}}
// expected-error @+1 {{unsupported function call: resource return value does not alias an input.}}
func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>> {
%0 = "tf._Unknown_"() : () -> tensor<*x!tf.resource<tensor<f32>>>
return %0 : tensor<*x!tf.resource<tensor<f32>>>

View File

@ -0,0 +1,254 @@
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-stack-ops-decomposition | FileCheck %s -dump-input-on-failure
// Tests simple scalar stack operations without control flow.
// CHECK-LABEL: func @main
func @main() -> tensor<f32> {
// CHECK-NEXT: "tf.Const"() {value = dense<10> : tensor<i32>}
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: %[[BUFFER:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<10xf32>>>
// CHECK-NEXT: %[[SIZE:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<1xi32>>>
// CHECK-NEXT: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[ZERO]])
// CHECK-NEXT: %[[ZERO_SCALAR:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: %[[CAST_ZERO:.*]] = "tf.Cast"(%[[ZERO_SCALAR]]) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: %[[CONST10:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK-NEXT: %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[CAST_ZERO]], %[[CONST10]]) : (tensor<f32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-NEXT: "tf.AssignVariableOp"(%[[BUFFER]], %[[BROADCAST]])
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
%id = "tf.Identity"(%stack) : (tensor<!tf.resource>) -> tensor<!tf.resource>
// CHECK-NEXT: %[[PUSHVAL:.*]] = "tf._SomeOp"()
%elem = "tf._SomeOp"() : () -> tensor<f32>
// CHECK-NEXT: %[[READ_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]])
// CHECK-NEXT: %[[READ_SIZE:.*]] = "tf.ReadVariableOp"(%[[SIZE]])
// CHECK-NEXT: %[[UPDATE_SHAPE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK-NEXT: %[[UPDATE_SLICE:.*]] = "tf.Reshape"(%[[PUSHVAL]], %[[UPDATE_SHAPE]]) : (tensor<f32>, tensor<1xi32>) -> tensor<1xf32>
// CHECK-NEXT: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ_VAL]], %[[UPDATE_SLICE]], %[[READ_SIZE]]) : (tensor<10xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-NEXT: "tf.AssignVariableOp"(%[[BUFFER]], %[[UPDATE]]) : (tensor<!tf.resource<tensor<10xf32>>>, tensor<10xf32>) -> ()
// CHECK-NEXT: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK-NEXT: %[[NEW_SIZE:.*]] = "tf.AddV2"(%[[READ_SIZE]], %[[CONST1]]) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[NEW_SIZE]]) : (tensor<!tf.resource<tensor<1xi32>>>, tensor<1xi32>) -> ()
%push = "tf.StackPushV2"(%id, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
%pop = "tf.StackPopV2"(%stack) : (tensor<!tf.resource>) -> tensor<f32>
// CHECK-NEXT: %[[READ_VAL1:.*]] = "tf.ReadVariableOp"(%[[BUFFER]])
// CHECK-NEXT: %[[READ_SIZE1:.*]] = "tf.ReadVariableOp"(%[[SIZE]])
// CHECK-NEXT: %[[CONST1_1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK-NEXT: %[[SUB:.*]] = "tf.Sub"(%[[READ_SIZE1]], %[[CONST1_1]])
// CHECK-NEXT: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK-NEXT: %[[SLICE:.*]] = "tf.Slice"(%[[READ_VAL1]], %[[SUB]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32>
// CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
// CHECK-NEXT: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor<f32>
// CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[SUB]]) : (tensor<!tf.resource<tensor<1xi32>>>, tensor<1xi32>) -> ()
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
// CHECK-NEXT: return %[[ELEM]] : tensor<f32>
return %pop : tensor<f32>
}
// -----
// Tests simple non-scalar stack operations without control flow.
// CHECK-LABEL: func @main
func @main() -> tensor<2xi32> {
// CHECK-NEXT: "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
%size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: %[[BUFFER:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<10x2xi32>>>
// CHECK-NEXT: %[[SIZE:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<1xi32>>>
// CHECK-NEXT: %[[ZERO_SIZE:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[ZERO_SIZE]]) : (tensor<!tf.resource<tensor<1xi32>>>, tensor<1xi32>) -> ()
// CHECK-NEXT: %[[ZERO_CONST:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: %[[STACK_SHAPE:.*]] = "tf.Const"() {value = dense<[10, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK-NEXT: %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[ZERO_CONST]], %[[STACK_SHAPE]]) : (tensor<i32>, tensor<2xi32>) -> tensor<10x2xi32>
// CHECK-NEXT: "tf.AssignVariableOp"(%[[BUFFER]], %[[BROADCAST]]) : (tensor<!tf.resource<tensor<10x2xi32>>>, tensor<10x2xi32>) -> ()
%stack = "tf.StackV2"(%size) {elem_type = i32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
// CHECK-NEXT: %[[PUSH_VAL:.*]] = "tf._SomeOp"() : () -> tensor<2xi32>
%elem = "tf._SomeOp"() : () -> tensor<2xi32>
// CHECK-NEXT: %[[STACK_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]]) : (tensor<!tf.resource<tensor<10x2xi32>>>) -> tensor<10x2xi32>
// CHECK-NEXT: %[[STACK_SIZE:.*]] = "tf.ReadVariableOp"(%[[SIZE]]) : (tensor<!tf.resource<tensor<1xi32>>>) -> tensor<1xi32>
// CHECK-NEXT: %[[UPDATE_SHAPE:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK-NEXT: %[[UPDATE_SLICE:.*]] = "tf.Reshape"(%[[PUSH_VAL]], %[[UPDATE_SHAPE]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<1x2xi32>
// CHECK-NEXT: %[[ZERO_INDS:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK-NEXT: %[[CONCAT_DIM:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
// CHECK-NEXT: %[[CONCAT_OFFETS:.*]] = "tf.ConcatV2"(%[[STACK_SIZE]], %[[ZERO_INDS]], %[[CONCAT_DIM]]) : (tensor<1xi32>, tensor<1xi32>, tensor<i32>) -> tensor<2xi32>
// CHECK-NEXT: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[STACK_VAL]], %[[UPDATE_SLICE]], %[[CONCAT_OFFETS]]) : (tensor<10x2xi32>, tensor<1x2xi32>, tensor<2xi32>) -> tensor<10x2xi32>
// CHECK-NEXT: "tf.AssignVariableOp"(%[[BUFFER]], %[[UPDATE]]) : (tensor<!tf.resource<tensor<10x2xi32>>>, tensor<10x2xi32>) -> ()
// CHECK-NEXT: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK-NEXT: %[[NEW_SIZE:.*]] = "tf.AddV2"(%[[STACK_SIZE]], %[[CONST1]]) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[NEW_SIZE]]) : (tensor<!tf.resource<tensor<1xi32>>>, tensor<1xi32>) -> ()
%push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<2xi32>) -> tensor<2xi32>
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
// CHECK-NEXT: return %[[PUSH_VAL]] : tensor<2xi32>
return %push : tensor<2xi32>
}
// -----
// Tests while loop.
// CHECK-LABEL: func @main
func @main() -> () {
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
// CHECK-NOT: tf.Stack
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
%1:2 = "tf.While"(%stack, %max_size) {
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
: (tensor<!tf.resource>, tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>)
// CHECK: "tf.Slice"
%pop = "tf.StackPopV2"(%1#0) : (tensor<!tf.resource>) -> tensor<f32>
// CHECK-NOT: tf.Stack
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
// CHECK: return
return
}
// CHECK: func @while_body(%[[BARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[BARG1:.*]]: tensor<i32>, %[[BARG2:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
func @while_body(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>) {
// CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG1]], %[[CONST1]])
%sub = "tf.Sub"(%arg1, %const1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%elem = "tf._SomeOp"() : () -> tensor<f32>
// CHECK-NOT: "tf.StackPushV2"
// CHECK: "tf.XlaDynamicUpdateSlice"
// CHECK-NOT: "tf.StackPushV2"
%push = "tf.StackPushV2"(%arg0, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[BARG0]], %[[SUB]], %[[BARG2]]
return %arg0, %sub : tensor<!tf.resource>, tensor<i32>
}
// CHECK: func @while_cond(%[[CARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[CARG1:.*]]: tensor<i32>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
func @while_cond(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> tensor<i32> {
// CHECK-NEXT: return %[[CARG1]]
return %arg1 : tensor<i32>
}
// -----
// Tests IfOp.
// CHECK-LABEL: func @main
func @main(%arg0: tensor<i1>) -> () {
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
// CHECK-NOT: tf.Stack
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
%if_op = "tf.If"(%arg0, %stack) {then_branch = @if_then, else_branch = @if_else, is_stateless = false}
: (tensor<i1>, tensor<!tf.resource>) -> tensor<!tf.resource>
// CHECK: "tf.Slice"
%pop = "tf.StackPopV2"(%if_op) : (tensor<!tf.resource>) -> tensor<f32>
// CHECK-NOT: tf.Stack
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
// CHECK: return
return
}
// CHECK: func @if_then(%[[TARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[TARG1:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
func @if_then(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
%elem = "tf._SomeOp"() : () -> tensor<f32>
// CHECK-NOT: "tf.StackPushV2"
// CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
// CHECK: "tf.AssignVariableOp"(%[[TARG0:.*]], %[[UPDATE]])
// CHECK: "tf.AssignVariableOp"(%[[EARG1:.*]],
// CHECK-NOT: "tf.StackPushV2"
%push = "tf.StackPushV2"(%arg0, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
return %arg0 : tensor<!tf.resource>
}
// CHECK: func @if_else(%[[EARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[EARG1:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
func @if_else(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
// CHECK-NOT: "tf.StackPushV2"
// CHECK: "tf.Slice"
// CHECK: "tf.AssignVariableOp"(%[[EARG1:.*]],
// CHECK-NOT: "tf.StackPushV2"
%pop = "tf.StackPopV2"(%arg0) : (tensor<!tf.resource>) -> tensor<f32>
return %arg0 : tensor<!tf.resource>
}
// -----
// Tests PartitionedCall/StatefulPartitionedCall.
// CHECK-LABEL: func @main
func @main(%arg0: tensor<i1>) -> () {
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
// CHECK-NOT: tf.Stack
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
// CHECK: "tf.StatefulPartitionedCall"
// CHECK-SAME: f = @callee_stack_decomposed
%call = "tf.StatefulPartitionedCall"(%stack, %arg0) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<!tf.resource>, tensor<i1>) -> tensor<!tf.resource>
// CHECK: "tf.PartitionedCall"
// CHECK-SAME: f = @callee_stack_decomposed
%call2 = "tf.PartitionedCall"(%stack, %arg0) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<!tf.resource>, tensor<i1>) -> tensor<!tf.resource>
// CHECK: "tf.Slice"
%pop = "tf.StackPopV2"(%call) : (tensor<!tf.resource>) -> tensor<f32>
// CHECK-NOT: tf.Stack
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
// CHECK: return
return
}
// CHECK: func @callee(%[[AARG0:.*]]: tensor<!tf.resource>, %[[AARG1:.*]]: tensor<i1>) -> tensor<!tf.resource>
func @callee(%arg0: tensor<!tf.resource>, %arg1: tensor<i1>) -> tensor<!tf.resource> {
%elem = "tf._SomeOp"(%arg1) : (tensor<i1>) -> tensor<f32>
// CHECK: "tf.StackPushV2"
%push = "tf.StackPushV2"(%arg0, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
return %arg0 : tensor<!tf.resource>
}
// CHECK: func @callee_stack_decomposed(%[[ARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
// CHECK-NOT: "tf.StackPushV2"
// CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
// CHECK: "tf.AssignVariableOp"(%[[TARG0:.*]], %[[UPDATE]])
// CHECK: "tf.AssignVariableOp"(%[[EARG1:.*]],
// CHECK-NOT: "tf.StackPushV2"
// -----
// Tests that the pass reports error on unknown stack size.
func @main(%arg0: tensor<i32>) -> tensor<2xi32> {
// expected-error @+1 {{max size of stack is not a constant.}}
%stack = "tf.StackV2"(%arg0) {elem_type = i32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
%elem = "tf._SomeOp"() : () -> tensor<2xi32>
%push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<2xi32>) -> tensor<2xi32>
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
return %push : tensor<2xi32>
}
// -----
// Tests that the pass reports error on unknown element shape.
func @main(%arg0: tensor<i32>) -> () {
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
// expected-error @+1 {{cannot infer element shape of stack.}}
%stack = "tf.StackV2"(%max_size) {elem_type = i32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
%elem = "tf._SomeOp"() : () -> tensor<*xi32>
%push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<*xi32>) -> tensor<*xi32>
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
return
}
// -----
// Tests that the pass reports error on ambiguous stack.
func @main(%arg0: tensor<i1>) -> () {
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
%stack2 = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s2"} : (tensor<i32>) -> tensor<!tf.resource>
%if_op = "tf.If"(%arg0, %stack, %stack2) {then_branch = @if_then, else_branch = @if_else, is_stateless = false}
: (tensor<i1>, tensor<!tf.resource>, tensor<!tf.resource>) -> tensor<!tf.resource>
// expected-error @+1 {{unknown stack.}}
%pop = "tf.StackPopV2"(%if_op) : (tensor<!tf.resource>) -> tensor<f32>
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
// CHECK: return
return
}
func @if_then(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
%elem = "tf._SomeOp"() : () -> tensor<f32>
%push = "tf.StackPushV2"(%arg0, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
return %arg0 : tensor<!tf.resource>
}
func @if_else(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
%elem = "tf._SomeOp"() : () -> tensor<f32>
%push = "tf.StackPushV2"(%arg1, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
return %arg1 : tensor<!tf.resource>
}

View File

@ -4,12 +4,16 @@
// CHECK: func @non_replicated(%[[ARG0:.*]]: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32>
func @non_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32> {
// CHECK: %[[COMPILE:.*]]:2 = "tf._TPUCompileMlir"()
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
// The metadata encodes 2 parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
%compile:2 = "tf_device.launch"() ( {
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
// CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
// CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false}
// CHECK: %[[ITER:.*]]:2 = "tf.IteratorGetNext"
@ -17,11 +21,20 @@ func @non_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
// CHECK-DAG: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER]]#0, %[[LAYOUT0]]) {device = "/device:TPU:0"}
// CHECK-DAG: %[[COPY1:.*]] = "tf.TPUCopyWithLayout"(%[[ITER]]#1, %[[LAYOUT1]]) {device = "/device:TPU:0"}
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
// CHECK: "tf.TPUExecute"(%[[COPY0]], %[[COPY1]], %[[COMPILE]]#1) {device = "/device:TPU:0"}
%3 = "tf.TPUExecute"(%2#0, %2#1, %1#1) {device = "/device:TPU:0"}
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
return %3 : tensor<i32>
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[COPY0]], %[[COPY1]], %[[COMPILE]]#1)
%execute = "tf_device.launch"() ( {
%3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1)
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
tf_device.return %3 : tensor<i32>
}) {device = "/device:TPU:0"} : () -> tensor<i32>
return %execute : tensor<i32>
}
// -----
@ -31,22 +44,34 @@ func @non_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}
// CHECK-LABEL: func @multiple_compile_uses
func @multiple_compile_uses(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32> {
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
// The metadata encodes 2 parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
%compile:2 = "tf_device.launch"() ( {
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
// CHECK-NOT: "tf.TPUGetLayoutOp"
// CHECK-NOT: "tf.TPUCopyWithLayout"
%2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
%3 = "tf.TPUExecute"(%2#0, %2#1, %1#1) {device = "/device:TPU:0"}
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
%execute0 = "tf_device.launch"() ( {
%3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1)
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
tf_device.return %3 : tensor<i32>
}) {device = "/device:TPU:0"} : () -> tensor<i32>
%4:2 = "tf._UnKnownOp_"() : () -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
%5 = "tf.TPUExecute"(%4#0, %4#1, %1#1) {device = "/device:TPU:0"}
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
return %5 : tensor<i32>
%execute1 = "tf_device.launch"() ( {
%5 = "tf.TPUExecute"(%4#0, %4#1, %compile#1)
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
tf_device.return %5 : tensor<i32>
}) {device = "/device:TPU:0"} : () -> tensor<i32>
return %execute1 : tensor<i32>
}
// -----
@ -55,19 +80,28 @@ func @multiple_compile_uses(%arg0: tensor<*x!tf.resource> {tf.device = "/device:
// CHECK-LABEL: func @on_tpu_iter
func @on_tpu_iter(%arg0: tensor<*x!tf.resource> {tf.device = "/device:TPU:0"}) -> tensor<i32> {
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
// The metadata encodes 2 parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
%compile:2 = "tf_device.launch"() ( {
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
// CHECK-NOT: "tf.TPUGetLayoutOp"
// CHECK-NOT: "tf.TPUCopyWithLayout"
%2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:TPU:0"}
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
%3 = "tf.TPUExecute"(%2#0, %2#1, %1#1) {device = "/device:TPU:0"}
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
return %3 : tensor<i32>
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
%execute = "tf_device.launch"() ( {
%3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1)
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
tf_device.return %3 : tensor<i32>
}) {device = "/device:TPU:0"} : () -> tensor<i32>
return %execute : tensor<i32>
}
// -----
@ -76,18 +110,27 @@ func @on_tpu_iter(%arg0: tensor<*x!tf.resource> {tf.device = "/device:TPU:0"}) -
// CHECK-LABEL: func @unsupported_ops
func @unsupported_ops(%arg0: tensor<3x3x1x32xf32>) -> tensor<i32> {
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
// The metadata encodes 2 parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
%compile:2 = "tf_device.launch"() ( {
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
// CHECK-NOT: "tf.TPUGetLayoutOp"
// CHECK-NOT: "tf.TPUCopyWithLayout"
%2 = "tf._Unknown_"() : () -> tensor<3x3x1x32xf32>
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
%3 = "tf.TPUExecute"(%arg0, %2, %1#1) {device = "/device:TPU:0"}
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
return %3 : tensor<i32>
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
%execute = "tf_device.launch"() ( {
%3 = "tf.TPUExecute"(%arg0, %2, %compile#1)
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
tf_device.return %3 : tensor<i32>
}) {device = "/device:TPU:0"} : () -> tensor<i32>
return %execute : tensor<i32>
}
// -----
@ -99,12 +142,16 @@ func @replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) ->
// CHECK: %[[ITER0:.*]]:2 = "tf.IteratorGetNext"
%2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
// CHECK: %[[COMPILE:.*]]:2 = "tf._TPUCompileMlir"()
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
// The metadata encodes 2 parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
%compile:2 = "tf_device.launch"() ( {
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
// CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
// CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false}
// CHECK-DAG: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#0, %[[LAYOUT0]]) {device = "/device:TPU:0"}
@ -114,13 +161,19 @@ func @replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) ->
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
// CHECK-DAG: %[[COPY2:.*]] = "tf.TPUCopyWithLayout"(%[[ITER1]]#0, %[[LAYOUT0]]) {device = "/device:TPU:1"}
// CHECK-DAG: %[[COPY3:.*]] = "tf.TPUCopyWithLayout"(%[[ITER1]]#1, %[[LAYOUT1]]) {device = "/device:TPU:1"}
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
// CHECK: tf_device.replicate([%[[COPY0]], %[[COPY2]]] as %[[R0:.*]]: tensor<3x3x1x32xf32>, [%[[COPY1]], %[[COPY3]]] as %[[R1:.*]]: tensor<3x3x1x32xf32>)
%5:2 = tf_device.replicate([%2#0, %3#0] as %r0: tensor<3x3x1x32xf32>, [%2#1, %3#1] as %r1: tensor<3x3x1x32xf32>)
{n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}} {
// CHECK: "tf.TPUExecute"(%[[R0]], %[[R1]], %[[COMPILE]]#1)
%4 = "tf.TPUExecute"(%r0, %r1, %1#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
tf_device.return %4 : tensor<i32>
%execute = "tf_device.launch"() ( {
%4 = "tf.TPUExecute"(%r0, %r1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
tf_device.return %4 : tensor<i32>
}) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i32>
tf_device.return %execute : tensor<i32>
}
return %5#0 : tensor<i32>
}
@ -131,20 +184,29 @@ func @replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) ->
// CHECK-LABEL: func @inside_replicated
func @inside_replicated(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>) -> tensor<i32> {
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
// The metadata encodes 2 parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
%compile:2 = "tf_device.launch"() ( {
%1:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes 2 parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
// CHECK-NOT: "tf.TPUGetLayoutOp"
// CHECK-NOT: "tf.TPUCopyWithLayout"
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
%5:2 = tf_device.replicate([%arg0, %arg1] as %r0: tensor<*x!tf.resource>)
{n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}} {
%2:2 = "tf.IteratorGetNext"(%r0)
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
%4 = "tf.TPUExecute"(%2#0, %2#1, %1#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
tf_device.return %4 : tensor<i32>
%execute = "tf_device.launch"() ( {
%4 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
tf_device.return %4 : tensor<i32>
}) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i32>
tf_device.return %execute : tensor<i32>
}
return %5#0 : tensor<i32>
}

View File

@ -22,16 +22,21 @@ func @merge_same_device_variables(
%read0 = "tf.ReadVariableOp"(%id0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<64xf32>>>) -> tensor<64xf32>
%read2 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf.resource<tensor<16xf32>>>) -> tensor<16xf32>
// CHECK-NEXT: %[[EXE:.*]] = "tf.TPUExecuteAndUpdateVariables"(%[[ID_0]], %[[ARG_1]], %[[READ_2]], %[[ARG_3]])
// CHECK-NEXT: %[[EXE:.*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[ID_0]], %[[ARG_1]], %[[READ_2]], %[[ARG_3]])
// CHECK-SAME: device_var_reads_indices = [0, 1],
// CHECK-SAME: device_var_updates_indices = [0, -1]
%execute:2 = "tf.TPUExecute"(%read0, %read1, %read2, %arg3) {
Targs = [tensor<32xf32>, tensor<64xf32>, tensor<16xf32>],
Tresults = [tensor<32xf32>, tensor<16xf32>],
device = "/job:localhost/replica:0/task:0/device:TPU:0"}
: (tensor<32xf32>, tensor<64xf32>, tensor<16xf32>, tensor<!tf.string>) -> (tensor<32xf32>, tensor<16xf32>)
// CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_2]], %[[EXE]])
%execute:2 = "tf_device.launch"() ( {
%0:2 = "tf.TPUExecute"(%read0, %read1, %read2, %arg3) {
Targs = [tensor<32xf32>, tensor<64xf32>, tensor<16xf32>],
Tresults = [tensor<32xf32>, tensor<16xf32>]}
: (tensor<32xf32>, tensor<64xf32>, tensor<16xf32>, tensor<!tf.string>) -> (tensor<32xf32>, tensor<16xf32>)
tf_device.return %0#0, %0#1 : tensor<32xf32>, tensor<16xf32>
}) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<16xf32>)
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"}
"tf.AssignVariableOp"(%id0, %execute#0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
// CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_2]], %[[EXE]])
"tf.AssignVariableOp"(%arg2, %execute#1) : (tensor<*x!tf.resource<tensor<16xf32>>>, tensor<16xf32>) -> ()
// CHECK-NEXT: tf_executor.yield
tf_executor.yield
@ -61,12 +66,18 @@ func @merge_replicated_variables(
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// CHECK-NEXT: tf_device.replicate([%[[ARG_2]], %[[ARG_3]]] as %[[R_ARG:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>)
tf_device.replicate([%arg2, %arg3] as %r: tensor<*x!tf.resource<tensor<32xf32>>>) {n = 2 : i32} {
// CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[R_ARG]], %[[ARG_1]])
// CHECK-SAME: device_var_reads_indices = [1],
// CHECK-SAME: device_var_updates_indices = [0]
%read1 = "tf.ReadVariableOp"(%r) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
%execute = "tf.TPUExecute"(%read0, %read1, %arg1)
: (tensor<32xf32>, tensor<32xf32>, tensor<!tf.string>) -> tensor<32xf32>
%execute = "tf_device.launch"() ( {
%0 = "tf.TPUExecute"(%read0, %read1, %arg1)
: (tensor<32xf32>, tensor<32xf32>, tensor<!tf.string>) -> tensor<32xf32>
tf_device.return %0 : tensor<32xf32>
}) {device = ""} : () -> tensor<32xf32>
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: }) {device = ""}
"tf.AssignVariableOp"(%r, %execute) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
// CHECK-NEXT: tf_device.return
tf_device.return
@ -114,15 +125,20 @@ func @interferencing_accesses(
"tf.AssignVariableOp"(%arg5, %arg6) : (tensor<*x!tf.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<64xf32>>>) -> tensor<64xf32>
%read2 = "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
// CHECK-NEXT: %[[EXE:.*]]:2 = "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[ARG_1]], %[[ARG_4]], %[[READ_5]], %[[ARG_3]])
// CHECK-NEXT: %[[EXE:.*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[ARG_1]], %[[ARG_4]], %[[READ_5]], %[[ARG_3]])
// CHECK-SAME: device_var_reads_indices = [1, 2],
// CHECK-SAME: device_var_updates_indices = [1, -1]
%execute:3 = "tf.TPUExecute"(%read0, %read1, %read2, %read5, %arg3) {
Targs = [tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>],
Tresults = [tensor<32xf32>, tensor<64xf32>, tensor<8xf32>],
device = "/job:localhost/replica:0/task:0/device:TPU:0"}
: (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>, tensor<!tf.string>)
-> (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>)
%execute:3 = "tf_device.launch"() ( {
%0:3 = "tf.TPUExecute"(%read0, %read1, %read2, %read5, %arg3) {
Targs = [tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>],
Tresults = [tensor<32xf32>, tensor<64xf32>, tensor<8xf32>]}
: (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>, tensor<!tf.string>)
-> (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>)
tf_device.return %0#0, %0#1, %0#2 : tensor<32xf32>, tensor<64xf32>, tensor<8xf32>
}) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>)
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"}
"tf.AssignVariableOp"(%arg1, %execute#1) : (tensor<*x!tf.resource<tensor<64xf32>>>, tensor<64xf32>) -> ()
// CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]#0)
"tf.AssignVariableOp"(%arg0, %execute#0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
@ -156,11 +172,16 @@ func @do_not_merge_multi_read(
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// CHECK-NEXT: %[[READ_1:.*]] = "tf.ReadVariableOp"(%[[ARG_0]])
%read1 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// CHECK-NEXT: %[[EXE:.*]] = "tf.TPUExecute"(%[[READ_0]], %[[READ_1]], %[[ARG_1]])
%execute = "tf.TPUExecute"(%read0, %read1, %arg1) {
Targs = [tensor<32xf32>, tensor<32xf32>], Tresults = [tensor<32xf32>],
device = "/job:localhost/replica:0/task:0/device:TPU:0"}
: (tensor<32xf32>, tensor<32xf32>, tensor<!tf.string>) -> (tensor<32xf32>)
// CHECK-NEXT: %[[EXE:.*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[READ_0]], %[[READ_1]], %[[ARG_1]])
%execute = "tf_device.launch"() ( {
%0 = "tf.TPUExecute"(%read0, %read1, %arg1) {
Targs = [tensor<32xf32>, tensor<32xf32>], Tresults = [tensor<32xf32>]}
: (tensor<32xf32>, tensor<32xf32>, tensor<!tf.string>) -> (tensor<32xf32>)
tf_device.return %0 : tensor<32xf32>
}) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<32xf32>
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"}
// CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]])
"tf.AssignVariableOp"(%arg0, %execute) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
// CHECK-NEXT: tf_executor.yield
@ -187,11 +208,16 @@ func @do_not_merge_multi_assign(
%island = tf_executor.island {
// CHECK-NEXT: %[[READ_0:.*]] = "tf.ReadVariableOp"(%[[ARG_0]])
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// CHECK-NEXT: %[[EXE:.*]]:2 = "tf.TPUExecute"(%[[READ_0]], %[[ARG_1]])
%execute:2 = "tf.TPUExecute"(%read0, %arg1) {
Targs = [tensor<32xf32>], Tresults = [tensor<32xf32>, tensor<32xf32>],
device = "/job:localhost/replica:0/task:0/device:TPU:0"}
: (tensor<32xf32>, tensor<!tf.string>) -> (tensor<32xf32>, tensor<32xf32>)
// CHECK-NEXT: %[[EXE:.*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[READ_0]], %[[ARG_1]])
%execute:2 = "tf_device.launch"() ( {
%0:2 = "tf.TPUExecute"(%read0, %arg1) {
Targs = [tensor<32xf32>], Tresults = [tensor<32xf32>, tensor<32xf32>]}
: (tensor<32xf32>, tensor<!tf.string>) -> (tensor<32xf32>, tensor<32xf32>)
tf_device.return %0#0, %0#1 : tensor<32xf32>, tensor<32xf32>
}) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<32xf32>)
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"}
// CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]#0)
"tf.AssignVariableOp"(%arg0, %execute#0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
// CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]#1)

View File

@ -31,7 +31,11 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// CHECK-SAME: as %[[V0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
// CHECK-SAME: as %[[V1:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
// CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[STATE:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>>
// CHECK: "tf.TPUReshardVariables"(%[[V0]], %[[V1]], %[[DEFAULT]], %[[STATE]])
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[V0]], %[[V1]], %[[DEFAULT]], %[[STATE]])
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
return
}
// CHECK: func @while_body_7560
@ -51,27 +55,41 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// CHECK-SAME: %[[STATE_ARG1:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>> {tf.device = "/device:TPU:1"})
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK: %[[COMPILE:.*]]:2 = "tf._TPUCompileMlir"()
%2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
// The metadata encodes 2 parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
"tf.TPUCompileSucceededAssert"(%2#0) : (tensor<!tf.string>) -> ()
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
%compile:2 = "tf_device.launch"() ( {
%2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes two parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[BODY_ARG1]], %[[BODY_ARG2]]] as %[[R0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
// CHECK-SAME: [%[[BODY_ARG3]], %[[BODY_ARG4]]] as %[[R1:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
// CHECK-SAME: [%[[STATE_ARG0]], %[[STATE_ARG1]]] as %[[R_STATE:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>>
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
%rep:2 = tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
[%arg3, %arg4] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
// CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]])
%id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
// CHECK: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
// CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %2#1)
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<!tf.string>) -> ()
"tf_device.launch"() ( {
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<!tf.string>) -> ()
tf_device.return
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
tf_device.return %ret : tensor<i32>
}
@ -131,12 +149,18 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>) {
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
// The metadata encodes two parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
"tf.TPUCompileSucceededAssert"(%2#0) : (tensor<!tf.string>) -> ()
%compile:2 = "tf_device.launch"() ( {
%2:2 = "tf._TPUCompileMlir"() {
NumDynamicShapes = 0 : i64,
// The metadata encodes two parameters and two return values.
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<!tf.string>
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
"tf_device.launch"() ( {
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
tf_device.return
}) {device = "/device:CPU:0"} : () -> ()
%id0 = "tf.Identity"(%arg3) : (tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) -> tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
"tf._Unknown_"(%id0) : (tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) -> ()
%newvar = "tf._SomeOp"() : () -> tensor<*x!tf.resource<tensor<f32>>>
@ -146,10 +170,13 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
// %arg30 is used in the cond function, %arg31 has other uses (%id0), and
// %arg32 is not a pass-through.
"tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %arg32, %2#1)
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
tensor<*x!tf.resource<tensor<f32>>>, tensor<!tf.string>) -> ()
"tf_device.launch"() ( {
"tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %arg32, %compile#1)
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
tensor<*x!tf.resource<tensor<f32>>>, tensor<!tf.string>) -> ()
tf_device.return
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
tf_device.return
}
return %1, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6 : tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>,

View File

@ -431,16 +431,21 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
// CHECK-SAME: device = "/job:worker/replica:0/task:0/device:TPU:0"
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
// CHECK: device = "/job:worker/replica:0/task:0/device:TPU:0"
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
@ -472,15 +477,20 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-SAME: n = 2
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[RI_0]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[RI_0]], %[[COMPILE_OUTPUT]]#1)
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[RI_0]], %[[COMPILE_OUTPUT]]#1)
%2 = "tf_device.launch_func"(%ri_0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: tf_device.return %[[EXECUTE_OUTPUT]]
@ -539,7 +549,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
@ -548,7 +559,13 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-SAME: func @nested_func
// CHECK-SAME: tf.D
// CHECK-NOT: func = @tpu0_func
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
// CHECK: device = "/job:worker/replica:0/task:0/device:TPU:0"
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
@ -582,7 +599,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
@ -591,7 +609,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-SAME: func @referenced_func
// CHECK-SAME: tf.D
// CHECK-NOT: func = @tpu0_func
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
@ -624,7 +645,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
@ -635,7 +657,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-SAME: @referenced_func2
// CHECK-SAME: tf.E
// CHECK-NOT: func = @tpu0_func
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
@ -674,7 +699,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
@ -684,7 +710,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-COUNT-1: func @referenced_func
// CHECK-SAME: tf.D
// CHECK-NOT: func = @tpu0_func
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
@ -718,25 +747,33 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func0, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func0
// CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1)
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE0_OUTPUT]]#0)
// CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1)
%2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "", func = @tpu0_func1, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]])
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.D
// CHECK-NOT: func = @tpu0_func1
// CHECK: %[[EXECUTE1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[EXECUTE0_OUTPUT]], %[[COMPILE1_OUTPUT]]#1)
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE1_OUTPUT]]#0)
// CHECK: %[[EXECUTE1_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[EXECUTE0_OUTPUT]], %[[COMPILE1_OUTPUT]]#1)
%3 = "tf.C"(%2) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE1_OUTPUT]])
@ -768,25 +805,33 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func
// CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1)
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE0_OUTPUT]]#0)
// CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1)
%2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]])
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-NOT: func = @tpu0_func
// CHECK: %[[EXECUTE1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[EXECUTE0_OUTPUT]], %[[COMPILE1_OUTPUT]]#1)
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE1_OUTPUT]]#0)
// CHECK: %[[EXECUTE1_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[EXECUTE0_OUTPUT]], %[[COMPILE1_OUTPUT]]#1)
%3 = "tf.C"(%2) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE1_OUTPUT]])
@ -814,7 +859,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
// CHECK-SAME: NumDynamicShapes = 1
// CHECK-SAME: metadata
// CHECK-SAME: mlir_module
@ -828,7 +874,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-SAME: tf.G
// CHECK-SAME: func @referenced_func0
// CHECK-SAME: tf.F
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
@ -873,8 +922,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-LABEL: func @tpu_compilation_result
func @tpu_compilation_result(%arg0: tensor<?xi32>) -> (tensor<?xi32>, tensor<!tf.string>, tensor<!tf.string>) {
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecute"
%1 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = []} : (tensor<?xi32>) -> tensor<?xi32>
%compile_result = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor<!tf.string>

View File

@ -0,0 +1,147 @@
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-sharding-identification | FileCheck %s --dump-input=fail
// Tests empty launch func. Empty input/output sharding configuration
// attributes must be added.
// CHECK-LABEL: func @check_sharding_attrs_exists_for_empty_launch_func
func @check_sharding_attrs_exists_for_empty_launch_func() {
"tf_device.launch_func"() {device = "", func = @empty_func, step_marker_location = ""} : () -> ()
// CHECK: input_sharding_configuration = []
// CHECK: output_sharding_configuration = []
return
}
func @empty_func() {
return
}
// -----
// Tests that inputs/outputs with no XlaSharding op attached get the
// default maximal(0) sharding configuration.
// CHECK-LABEL: func @check_default_sharding_for_inputs_outputs
func @check_default_sharding_for_inputs_outputs(%arg0: tensor<*xi32>) {
"tf_device.launch_func"(%arg0) {device = "", func = @func_without_sharding, step_marker_location = ""} : (tensor<*xi32>) -> ()
// CHECK: input_sharding_configuration
// CHECK-SAME: ["\08\01\1A\01\01\22\01\00"]
// CHECK: output_sharding_configuration
// CHECK-SAME: ["\08\01\1A\01\01\22\01\00"]
return
}
func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> {
%0 = "tf.A"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
return %0 : tensor<*xi32>
}
// -----
// Tests an input arg connected to an XlaSharding op.
// CHECK-LABEL: func @check_sharding_for_input_correctly_identified
func @check_sharding_for_input_correctly_identified(%arg0: tensor<*xi32>) {
"tf_device.launch_func"(%arg0) {device = "", func = @inputs_with_sharding_func, step_marker_location = ""} : (tensor<*xi32>) -> ()
// CHECK: input_sharding_configuration
// CHECK-SAME: ["\01\02\03"]
// CHECK: output_sharding_configuration
// CHECK-SAME: ["\08\01\1A\01\01\22\01\00"]
return
}
func @inputs_with_sharding_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
%0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.A"(%0) : (tensor<*xi32>) -> (tensor<*xi32>)
return %1 : tensor<*xi32>
}
// -----
// Tests that sharding is correctly parsed for multiple inputs/outputs.
// CHECK-LABEL: func @check_sharding_for_multiple_inputs_outputs
func @check_sharding_for_multiple_inputs_outputs(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
"tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
// CHECK: input_sharding_configuration
// CHECK-SAME: ["\01\02\03", "\04\05\06"]
// CHECK: output_sharding_configuration
// CHECK-SAME: ["\0A\0B\0C", "\0D\0E\0F"]
return
}
func @func_with_sharding(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
%0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
%2, %3 = "tf.A"(%0, %1) : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
%4 = "tf.XlaSharding"(%2) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
%5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
return %4, %5 : tensor<*xi32> , tensor<*xi1>
}
// -----
// Tests with input sharding following an identity op.
// CHECK-LABEL: func @check_sharding_after_identity
func @check_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
"tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding_after_identity, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
// CHECK: input_sharding_configuration
// CHECK-SAME: ["\01\02\03", "\04\05\06"]
// CHECK: output_sharding_configuration
// CHECK-SAME: ["\0A\0B\0C", "\0D\0E\0F"]
return
}
func @func_with_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
%0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
%2 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
%3, %4 = "tf.A"(%1, %2) : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
%5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
%6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
return %5, %6 : tensor<*xi32> , tensor<*xi1>
}
// -----
// Tests with input sharding following a ReadVariable op.
// CHECK-LABEL: func @check_sharding_after_read_variable
func @check_sharding_after_read_variable(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
"tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding_after_read_variable, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
// CHECK: input_sharding_configuration
// CHECK-SAME: ["\01\02\03", "\04\05\06"]
// CHECK: output_sharding_configuration
// CHECK-SAME: ["\0A\0B\0C", "\0D\0E\0F"]
return
}
func @func_with_sharding_after_read_variable(%arg0: tensor<*x!tf.resource<tensor<32xf32>>>, %arg1: tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<*xi32>, tensor<*xi1>) {
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
%1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<32xf32>) -> tensor<32xf32>
%2 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
%3 = "tf.Identity"(%2) : (tensor<32xf32>) -> tensor<32xf32>
%4 = "tf.XlaSharding"(%3) { _XlaSharding = "\04\05\06" } : (tensor<32xf32>) -> tensor<32xf32>
%5, %6 = "tf.A"(%1, %3) : (tensor<32xf32>, tensor<32xf32>) -> (tensor<*xi32>, tensor<*xi1>)
%7 = "tf.XlaSharding"(%5) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
%8 = "tf.XlaSharding"(%6) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
return %7, %8 : tensor<*xi32> , tensor<*xi1>
}
// -----
// Tests with input sharding following an identity op and cast op.
// CHECK-LABEL: func @check_sharding_after_cast_op
func @check_sharding_after_cast_op(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
"tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding_after_cast, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
// CHECK: input_sharding_configuration
// CHECK-SAME: ["\01\02\03", "\04\05\06"]
// CHECK: output_sharding_configuration
// CHECK-SAME: ["\0A\0B\0C", "\0D\0E\0F"]
return
}
func @func_with_sharding_after_cast(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
%0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
%1 = "tf.Cast"(%0) : (tensor<*xi32>) -> tensor<*xi1>
%2 = "tf.XlaSharding"(%1) { _XlaSharding = "\01\02\03" } : (tensor<*xi1>) -> tensor<*xi1>
%3 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
%4, %5 = "tf.A"(%2, %3) : (tensor<*xi1>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
%6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
%7 = "tf.XlaSharding"(%5) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
return %6, %7 : tensor<*xi32> , tensor<*xi1>
}
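
For reference, the escaped bytes "\08\01\1A\01\01\22\01\00" checked in the default-sharding cases above are a serialized xla::OpSharding proto describing maximal sharding on device 0. A minimal C++ sketch of producing that byte string (illustrative only, using the standard XLA OpSharding protobuf API; the helper name is hypothetical and not part of this change):

#include <string>

#include "tensorflow/compiler/xla/xla_data.pb.h"

// Builds the default maximal(0) sharding configuration and serializes it to
// the attribute bytes expected by the checks above.
std::string DefaultMaximalShardingBytes() {
  xla::OpSharding sharding;
  sharding.set_type(xla::OpSharding::MAXIMAL);  // field 1 -> "\08\01"
  sharding.add_tile_assignment_dimensions(1);   // field 3 -> "\1A\01\01"
  sharding.add_tile_assignment_devices(0);      // field 4 -> "\22\01\00"
  return sharding.SerializeAsString();          // "\08\01\1A\01\01\22\01\00"
}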

View File

@ -1,23 +1,23 @@
// RUN: tf-opt -split-input-file -verify-diagnostics -tfl-unroll-batch-matmul %s | FileCheck %s
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-unroll-batch-matmul %s | FileCheck %s
func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
return %0 : tensor<2x3x4x6xf32>
// CHECK-LABEL: batchMatMulV2TwoDim
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
// CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>}
// CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
// CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>}
// CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[3, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[4, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[5, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_9:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>}
// CHECK: %[[cst_10:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK: %[[cst_11:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
@ -67,16 +67,16 @@ func @batchMatMulV2FlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>)
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulV2FlatInput
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi64>}
// CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>}
// CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>}
// CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>}
// CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>}
// CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[3, 4, 6]> : tensor<3xi64>}
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
@ -122,19 +122,19 @@ func @batchMatMulTwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>)
return %0 : tensor<2x3x4x6xf32>
// CHECK-LABEL: batchMatMulTwoDim
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
// CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>}
// CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
// CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>}
// CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[3, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[4, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[5, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_9:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>}
// CHECK: %[[cst_10:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK: %[[cst_11:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
@ -184,16 +184,16 @@ func @batchMatMulFlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -
return %0 : tensor<3x4x6xf32>
// CHECK-LABEL: batchMatMulFlatInput
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi64>}
// CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>}
// CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
// CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>}
// CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>}
// CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>}
// CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>}
// CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
// CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[3, 4, 6]> : tensor<3xi64>}
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>

View File

@ -25,6 +25,17 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
namespace mlir {
namespace {
// Adds a logger to the bridge pass manager.
void EnableLogging(PassManager *pm) {
// Print the whole module after each pass, which requires disabling
// multi-threading as well.
pm->disableMultithreading();
pm->enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
/*print_module_scope=*/true));
}
} // namespace
namespace TFTPU {
namespace {
void AddGraphExportLoweringPasses(OpPassManager &pm) {
@ -32,16 +43,14 @@ void AddGraphExportLoweringPasses(OpPassManager &pm) {
pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateToIslandPass());
pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
pm.addNestedPass<FuncOp>(TFDevice::CreateLaunchToDeviceAttributePass());
}
tensorflow::Status RunTPUBridge(
ModuleOp module, bool enable_logging,
llvm::function_ref<void(OpPassManager &pm)> pipeline_builder) {
PassManager bridge(module.getContext());
// Add logger to bridge passmanager.
if (enable_logging)
bridge.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>());
if (enable_logging) EnableLogging(&bridge);
// Populate a passmanager with the list of passes that implement the bridge.
pipeline_builder(bridge);
@ -117,10 +126,7 @@ tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
bool enable_logging,
bool enable_inliner) {
PassManager bridge(module.getContext());
// Add logger to bridge passmanager.
if (enable_logging)
bridge.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>());
if (enable_logging) EnableLogging(&bridge);
StandardPipelineOptions pipeline_options;
pipeline_options.enable_inliner.setValue(enable_inliner);

View File

@ -43,6 +43,9 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializePassthroughOpPass();
// Performs Shape Inference on the TensorFlow dialect using the global registry.
std::unique_ptr<OpPassBase<ModuleOp>> CreateTFShapeInferencePass();
// Optional pass which will unroll BatchMatMul and use only MatMul
std::unique_ptr<OpPassBase<FuncOp>> CreateUnrollBatchMatMulPassPass();
// Optimizes Tensorflow graph.
std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass();
@ -99,6 +102,11 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
// Performs resource lifting on the function body to hoist resource variable
// accesses outside all control flow statements.
LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function);
// Converts stack ops into operations on local variables, which can later be
// removed by resource lifting. Requires known maximum sizes of stacks and
// known element shapes of push ops.
std::unique_ptr<OpPassBase<ModuleOp>> CreateStackOpsDecompositionPass();
} // namespace TF
namespace TFControlFlow {
@ -209,6 +217,10 @@ std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUDynamicPaddingMapperPass();
// ops.
std::unique_ptr<OpPassBase<ModuleOp>> CreateTPURewritePass();
// Creates a pass that identifies XlaSharding ops in a launch op for TPU
// computation.
std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUShardingIdentificationPass();
// Creates a pass that merges device variable reads/updates into the surrounded
// TPUExecute node. This allows the execute node to perform in-place variable
// updates.

View File

@ -47,22 +47,10 @@ struct ReplicateToIslandPass : public FunctionPass<ReplicateToIslandPass> {
void runOnFunction() override;
};
// Get the device name for which replica_index-th block will execute.
llvm::StringRef GetDeviceNameFromAttribute(DictionaryAttr devices,
int replica_index) {
// TODO(b/148913020): Remove this constraint once model parallelism is
// supported.
DCHECK_EQ(devices.size(), 1);
Attribute device_attr = devices.begin()->second;
return device_attr.cast<ArrayAttr>()
.getValue()[replica_index]
.cast<StringAttr>()
.getValue();
}
// Creates islands per replica from `tf_device.replicate` region. TensorFlow ops
// will have their device set to the replica if they originally did not have a
// device assigned.
// Creates islands per replica from the `tf_device.replicate` region. If the
// device of a `tf_device.launch` op is an aliased device of the
// `tf_device.replicate`, the device will be remapped to an explicit device
// for the associated replica island.
llvm::SmallVector<tf_executor::IslandOp, 8> ExpandReplicateIntoReplicas(
const Dialect* tf_dialect, OpBuilder* builder,
tf_executor::IslandOp island_op, tf_device::ReplicateOp replicate_op,
@ -87,10 +75,6 @@ llvm::SmallVector<tf_executor::IslandOp, 8> ExpandReplicateIntoReplicas(
builder->setInsertionPoint(island_op);
BlockAndValueMapping mapping;
for (int i : llvm::seq<int>(0, num_replicas)) {
// Determine optional device.
llvm::StringRef device =
has_devices ? GetDeviceNameFromAttribute(devices.getValue(), i) : "";
// Create new island for replica.
auto replica = builder->create<tf_executor::IslandOp>(
island_op.getLoc(), output_types, control_type, replica_inputs);
@ -104,13 +88,13 @@ llvm::SmallVector<tf_executor::IslandOp, 8> ExpandReplicateIntoReplicas(
// Copy over replicate region into replica island.
replicate_op.body().cloneInto(&replica.body(), mapping);
// Assign all TF ops in island optional device, if device is set.
if (!device.empty()) {
StringAttr device_attr = builder->getStringAttr(device);
replica.walk([&](Operation* op) {
if (op->getDialect() != tf_dialect) return;
if (!op->getAttr(kDeviceAttr)) op->setAttr(kDeviceAttr, device_attr);
// Map aliased devices to explicit devices based on replica.
if (has_devices) {
replica.walk([&](tf_device::LaunchOp launch) {
if (auto device_by_replica = devices.getValue().get(launch.device()))
launch.setAttr(
kDeviceAttr,
device_by_replica.cast<ArrayAttr>()[i].cast<StringAttr>());
});
}
@ -124,16 +108,25 @@ llvm::SmallVector<tf_executor::IslandOp, 8> ExpandReplicateIntoReplicas(
// replicate results with new island outputs. A single island is created to
// forward results from each replica island. Control dependencies of individual
// replicas are added to the single island if the single island does not emit
// a result from the respective replica.
// a result from the respective replica. Devices are remapped from aliased
// devices to explicit devices for `tf_device.launch` ops.
//
// For example, the following:
//
// %0:2 = tf_executor.island(%control) {
// %1:4 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>)
// {n = 2 : i32, devices = ["/CPU:0", "/GPU:1"]} {
// %2 = "tf.opA"(%ri) : (tensor<i1>) -> tensor<i1>
// %3 = "tf.opB"(%2) : (tensor<i1>) -> tensor<i1>
// tf_device.return %2, %3 : tensor<i1>, tensor<i1>
// {n = 2 : i32,
// devices = {DEVICE_ALIAS_0 = ["/DEVICE:0", "/DEVICE:1"],
// DEVICE_ALIAS_1 = ["/DEVICE:2", "/DEVICE:3"]}} {
// %a = "tf_device.launch"() ( {
// %2 = "tf.opA"(%ri) : (tensor<i1>) -> tensor<i1>
// tf_device.return %2 : tensor<i1>
// }) {device = "DEVICE_ALIAS_0"} : () -> tensor<i1>
// %b = "tf_device.launch"() ( {
// %3 = "tf.opB"(%a) : (tensor<i1>) -> tensor<i1>
// tf_device.return %3 : tensor<i1>
// }) {device = "DEVICE_ALIAS_1"} : () -> tensor<i1>
// tf_device.return %a, %b : tensor<i1>, tensor<i1>
// }
// tf_executor.yield %1#0 : tensor<i1>
// }
@ -141,14 +134,26 @@ llvm::SmallVector<tf_executor::IslandOp, 8> ExpandReplicateIntoReplicas(
// gets lowered to:
//
// %0:3 = tf_executor.island(%control) {
// %1 = "tf.opA"(%arg0) {device = "/CPU:0"} : (tensor<i1>) -> tensor<i1>
// %2 = "tf.opB"(%1) {device = "/CPU:0"} : (tensor<i1>) -> tensor<i1>
// tf_executor.yield %1, %2 : tensor<i1>, tensor<i1>
// %a0 = "tf_device.launch"() ( {
// %1 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
// tf_device.return %1 : tensor<i1>
// }) {device = "/DEVICE:0"} : () -> tensor<i1>
// %b0 = "tf_device.launch"() ( {
// %2 = "tf.opB"(%a0) : (tensor<i1>) -> tensor<i1>
// tf_device.return %2 : tensor<i1>
// }) {device = "/DEVICE:2"} : () -> tensor<i1>
// tf_executor.yield %a0, %b0 : tensor<i1>, tensor<i1>
// }
// %3:3 = tf_executor.island(%control) {
// %4 = "tf.opA"(%arg1) {device = "/GPU:1"} : (tensor<i1>) -> tensor<i1>
// %5 = "tf.opB"(%4) {device = "/GPU:1"} : (tensor<i1>) -> tensor<i1>
// tf_executor.yield %4, %5 : tensor<i1>, tensor<i1>
// %a1 = "tf_device.launch"() ( {
// %4 = "tf.opA"(%arg1) : (tensor<i1>) -> tensor<i1>
// tf_device.return %4 : tensor<i1>
// }) {device = "/DEVICE:1"} : () -> tensor<i1>
// %b1 = "tf_device.launch"() ( {
// %5 = "tf.opB"(%a1) : (tensor<i1>) -> tensor<i1>
// tf_device.return %5 : tensor<i1>
// }) {device = "/DEVICE:3"} : () -> tensor<i1>
// tf_executor.yield %a1, %b1 : tensor<i1>, tensor<i1>
// }
// %6:2 = tf_executor.island(%3#2) {
// tf_executor.yield %0#0 : tensor<i1>

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
@ -365,9 +366,10 @@ LogicalResult FindResourceArgUseInfo(
auto return_op = func_op.front().getTerminator();
for (auto arg : func_op.getArguments()) {
if (!getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) continue;
auto& info = (*result)[arg.getArgNumber()];
ResourceArgUseInfo info;
info.used = false;
info.updated = false;
bool do_not_touch = false;
for (auto user : arg.getUsers()) {
if (user == return_op) continue;
if (auto read = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
@ -381,9 +383,16 @@ LogicalResult FindResourceArgUseInfo(
info.data_type = assign.value().getType();
continue;
}
user->emitError("Found unsupported operations on resource.");
if (llvm::isa<TF::StackPushV2Op>(user) ||
llvm::isa<TF::StackPopV2Op>(user)) {
// Stacks will be handled by a separate pass.
do_not_touch = true;
break;
}
user->emitOpError("found unsupported operations on resource.");
return failure();
}
if (!do_not_touch) (*result)[arg.getArgNumber()] = info;
}
return success();
}
@ -395,14 +404,18 @@ LogicalResult FindResourceArgUseInfo(
llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> MergeArgResourceUseInfo(
const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos0,
const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos1) {
auto result = infos0;
for (auto& entry : result) {
if (entry.getSecond().used) continue;
auto& info1_entry = *infos1.find(entry.getFirst());
if (info1_entry.getSecond().used) {
entry.getSecond().used = true;
entry.getSecond().updated |= info1_entry.getSecond().updated;
entry.getSecond().data_type = info1_entry.getSecond().data_type;
llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> result;
for (const auto& entry : infos0) {
auto info1_it = infos1.find(entry.getFirst());
// If the entry is missing in any input, we should not touch this entry.
if (info1_it == infos1.end()) continue;
auto& info = result[entry.getFirst()];
info = entry.getSecond();
if (info.updated) continue;
if (info1_it->getSecond().used) {
info.used = true;
info.updated = info1_it->getSecond().updated;
info.data_type = info1_it->getSecond().data_type;
}
}
return result;
@ -576,16 +589,16 @@ LogicalResult HanldeWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
for (auto arg : body.getArguments()) {
if (!getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) continue;
if (return_op->getOperand(arg.getArgNumber()) != arg) {
return return_op->emitError(
"Resource used in while loop is only supported when the resource "
"input and output alias each other in the loop body.");
return return_op->emitOpError(
"resource used in while loop is only supported when the ")
<< "resource input and output alias each other in the loop body.";
}
}
// FindResourceArgUseInfo will check supported resource ops (read and assign),
// but loop condition has additional requirement that it cannot write
// resources.
if (cond.walk([&](TF::AssignVariableOp assign) {
assign.emitError("Found resource write in loop condition.");
assign.emitOpError("found resource write in loop condition.");
return WalkResult::interrupt();
})
.wasInterrupted()) {
@ -668,7 +681,7 @@ LogicalResult HanldeWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
}
// Lifts loads/stores from an IfOp's branches.
LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
FuncOp else_branch) {
// Remove identity nodes to avoid aliasing.
RemoveIdentity(&then_branch.front());
@ -689,9 +702,8 @@ LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
auto else_aliasing_arg = else_retval.dyn_cast<BlockArgument>();
if (!then_aliasing_arg || !else_aliasing_arg ||
then_aliasing_arg.getArgNumber() != else_aliasing_arg.getArgNumber()) {
return if_op.emitOpError(
"Unsupported tf.IfOp output: resource does not alias a single "
"input.");
return if_op.emitOpError("unsupported tf.IfOp output: ")
<< "resource does not alias a single input.";
}
if_op.getResult(entry.index())
.replaceAllUsesWith(
@ -701,6 +713,7 @@ LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
int64_t non_resource_results = 0;
llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
llvm::SmallVector<Attribute, 4> new_output_shapes;
bool output_removed = false;
for (auto result : if_op.getResults()) {
if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
old_to_new_output_indices.push_back(non_resource_results++);
@ -713,6 +726,7 @@ LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
old_to_new_output_indices.push_back(-1);
then_branch.front().getTerminator()->eraseOperand(non_resource_results);
else_branch.front().getTerminator()->eraseOperand(non_resource_results);
output_removed = true;
}
llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> then_use_info;
@ -724,7 +738,7 @@ LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
// A resource is considered used as long as it is used in either branch.
auto resource_arg_uses =
MergeArgResourceUseInfo(then_use_info, else_use_info);
if (resource_arg_uses.empty()) return success();
if (resource_arg_uses.empty() && !output_removed) return success();
// Remove unused resources in functions.
llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
RemoveUnusedResourceArgumentsAndForwardedRetvals(
@ -847,9 +861,8 @@ LogicalResult HandlePartitionedCallOpCallee(
}
auto aliasing_arg = retval.dyn_cast<BlockArgument>();
if (!aliasing_arg) {
return callee.emitOpError(
"Unsupported function call: resource return value does not alias an "
"input.");
return callee.emitOpError("unsupported function call: ")
<< "resource return value does not alias an input.";
}
result->old_outputs_aliasing_old_inputs[entry.index()] =
aliasing_arg.getArgNumber();
@ -982,6 +995,8 @@ LogicalResult HoistForFunctionalControlFlow(
Block* block, ModuleOp module,
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>*
lifted_partitioned_call_callees) {
// Remove identity nodes to avoid aliasing.
RemoveIdentity(block);
for (Operation& op : llvm::make_early_inc_range(*block)) {
if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body()));
@ -1002,11 +1017,11 @@ LogicalResult HoistForFunctionalControlFlow(
lifted_partitioned_call_callees);
HoistForFunctionalControlFlow(&else_branch.front(), module,
lifted_partitioned_call_callees);
if (failed(HanldeIfOP(if_op, then_branch, else_branch))) return failure();
if (failed(HandleIfOP(if_op, then_branch, else_branch))) return failure();
} else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
if (!call_op.f().isa<FlatSymbolRefAttr>()) {
return call_op.emitError(
"Resource lifting does not support call with nested references.");
return call_op.emitOpError(
"resource lifting does not support call with nested references.");
}
auto callee = llvm::cast<FuncOp>(
module.lookupSymbol(call_op.f().getRootReference()));
@ -1024,6 +1039,24 @@ LogicalResult HoistForFunctionalControlFlow(
}
}
}
// Remove unused local variables.
ForwardStoreToLoad(block);
llvm::SmallVector<TF::MlirLocalVarOp, 8> local_vars;
for (Operation& op : *block) {
if (auto local_var = llvm::dyn_cast<TF::MlirLocalVarOp>(&op)) {
local_vars.push_back(local_var);
}
}
for (auto local_var : local_vars) {
if (llvm::all_of(local_var.resource().getUsers(),
[](const Operation* user) {
return llvm::isa<TF::AssignVariableOp>(user);
})) {
for (auto user : local_var.resource().getUsers()) user->erase();
local_var.erase();
}
}
return success();
}

View File

@ -0,0 +1,742 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"
namespace mlir {
namespace {
// A pass that converts stack operations to tensor operations and read/assign
// ops on local variables. A later resource lifting pass can further remove the
// local variables.
//
// This pass requires that the full shape of the stack can be inferred: 1) the
// maximum size needs to be a constant and 2) a push op can be found with a
// known shape, and all push ops need to have the same shape.
//
// A stack creation op "tf.StackV2" will be turned into two zero-initialized
// variables, for the buffer and current size. Each push will be turned into
// %old_val = "tf.ReadVariableOp"(%buffer)
// %old_size = "tf.ReadVariableOp"(%size)
// %offsets = "tf.ConcatV2"(%old_size, %other_dims_0s, %const0)
// %new_val = "tf.XlaDynamicUpdateSlice"(%old_val, %push_val, %offsets)
// "tf.AssignVariableOp"(%buffer, %new_val)
// %new_size = "tf.AddV2"(%old_size, %const1)
// "tf.AssignVariableOp"(%size, %new_size)
//
// and each pop will be turned into
//
// %old_val = "tf.ReadVariableOp"(%buffer)
// %old_size = "tf.ReadVariableOp"(%size)
// %new_size = "tf.Sub"(%old_size, %const1)
// %offsets = "tf.ConcatV2"(%old_size, %other_dims_0s, %const0)
// %slice = "tf.Slice"(%old_val, %offsets, %slice_size_const)
// %pop_result = "tf.Reshape"(%slice, %elem_size_const)
// "tf.AssignVariableOp"(%size, %new_size)
//
// The pass also works across control flow and functional calls.
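//
// For example (illustrative), a stack created by "tf.StackV2" with max size 10
// and element shape [2] of f32 becomes two zero-initialized local variables:
//   %buffer : tensor<!tf.resource<tensor<10x2xf32>>>
//   %size   : tensor<!tf.resource<tensor<1xi32>>>
// where %size holds the current number of elements as a length-1 vector.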
struct StackOpsDecompositionPass
: public ModulePass<StackOpsDecompositionPass> {
void runOnModule() override;
};
// Creates a ReadVariableOp on a local variable.
Value ReadLocalVariable(Value local_var, OpBuilder builder, Location loc) {
return builder
.create<TF::ReadVariableOp>(
loc,
ArrayRef<Type>{getElementTypeOrSelf(local_var.getType())
.cast<TF::ResourceType>()
.getSubtypes()[0]},
ArrayRef<Value>{local_var}, ArrayRef<NamedAttribute>{})
.value();
}
// Creates an AssignVariableOp on a local variable.
TF::AssignVariableOp WriteLocalVariable(Value local_var, Value value,
OpBuilder builder, Location loc) {
return builder.create<TF::AssignVariableOp>(loc, ArrayRef<Type>{},
ArrayRef<Value>{local_var, value},
ArrayRef<NamedAttribute>{});
}
// Creates an i32 scalar tf.Const.
TF::ConstOp CreateScalarConst(int value, OpBuilder builder, Location loc) {
tensorflow::Tensor scalar_tensor(tensorflow::DT_INT32, {});
scalar_tensor.scalar<tensorflow::int32>()() = value;
return builder.create<TF::ConstOp>(
loc, tensorflow::ConvertTensor(scalar_tensor, &builder).ValueOrDie());
}
// Creates an i32 vector tf.Const.
TF::ConstOp GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc) {
tensorflow::Tensor shape_tensor(tensorflow::DT_INT32,
{static_cast<int64_t>(r1.size())});
for (int i = 0; i < r1.size(); ++i) {
shape_tensor.vec<tensorflow::int32>()(i) = r1[i];
}
return builder.create<TF::ConstOp>(
loc, tensorflow::ConvertTensor(shape_tensor, &builder).ValueOrDie());
}
// Creates a rank-1 op that represents the offsets of the stack element in the
// stack buffer.
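// For example (illustrative), for a buffer of type tensor<10x2x3xf32> and an
// index value [4], the resulting offsets are the R1 tensor [4, 0, 0].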
Value GetIndicesForStackElement(Value index, Value stack_value,
OpBuilder builder, Location loc) {
auto stack_type = stack_value.getType().cast<RankedTensorType>();
if (stack_type.getShape().size() == 1) return index;
llvm::SmallVector<int64_t, 8> zeros(stack_type.getShape().size() - 1, 0);
auto zeros_tensor = GetR1Const(zeros, builder, loc);
return builder.create<TF::ConcatV2Op>(
loc,
ArrayRef<Type>{RankedTensorType::get(
{static_cast<int64_t>(stack_type.getShape().size())},
getElementTypeOrSelf(index.getType()))},
ArrayRef<Value>{index, zeros_tensor, CreateScalarConst(0, builder, loc)},
ArrayRef<NamedAttribute>{});
}
// Returns the type of the local variable for the stack size. It is a
// tensor<1xi32>, and we use R1 instead of a scalar because it is easier to
// concat it with other offsets.
Type GetSizeVarType(OpBuilder builder) {
auto size_type = RankedTensorType::get({1}, builder.getIntegerType(32));
return RankedTensorType::get(
{}, TF::ResourceType::get(ArrayRef<TensorType>{size_type},
builder.getContext()));
}
// Creates the buffer and size local variables for a stack.
std::pair<Value, Value> CreateVariablesForStack(TensorType stack_tensor_type,
TF::StackV2Op stack) {
OpBuilder builder(stack);
auto size_var_type = GetSizeVarType(builder);
auto var_type = RankedTensorType::get(
{}, TF::ResourceType::get(ArrayRef<TensorType>{stack_tensor_type},
stack.getContext()));
auto local_var = builder.create<TF::MlirLocalVarOp>(
stack.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{},
ArrayRef<NamedAttribute>{});
auto local_size_var = builder.create<TF::MlirLocalVarOp>(
stack.getLoc(), ArrayRef<Type>{size_var_type}, ArrayRef<Value>{},
ArrayRef<NamedAttribute>{});
// Zero-initialize the local vars.
WriteLocalVariable(local_size_var, GetR1Const({0LL}, builder, stack.getLoc()),
builder, stack.getLoc());
auto zero = CreateScalarConst(0, builder, stack.getLoc()).output();
if (getElementTypeOrSelf(zero.getType()) !=
stack_tensor_type.getElementType()) {
zero = builder.create<TF::CastOp>(
stack.getLoc(),
ArrayRef<Type>{
RankedTensorType::get({}, stack_tensor_type.getElementType())},
ArrayRef<Value>{zero}, ArrayRef<NamedAttribute>{});
}
auto broadcast = builder.create<TF::BroadcastToOp>(
stack.getLoc(), ArrayRef<Type>{stack_tensor_type},
ArrayRef<Value>{zero, GetR1Const(stack_tensor_type.getShape(), builder,
stack.getLoc())},
ArrayRef<NamedAttribute>{});
WriteLocalVariable(local_var, broadcast, builder, stack.getLoc());
return {local_var, local_size_var};
}
// Tries to infer the stack element type with full shape based on its uses.
llvm::Optional<RankedTensorType> GetStackElementType(Value stack,
ModuleOp module) {
for (auto& use : stack.getUses()) {
if (auto push = llvm::dyn_cast<TF::StackPushV2Op>(use.getOwner())) {
auto elem_type = push.elem().getType().dyn_cast<RankedTensorType>();
if (elem_type && elem_type.hasStaticShape()) {
return elem_type;
}
} else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(use.getOwner())) {
auto body = module.lookupSymbol<FuncOp>(while_op.body());
assert(body);
auto type_from_body =
GetStackElementType(body.getArgument(use.getOperandNumber()), module);
if (type_from_body.hasValue()) return type_from_body;
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(use.getOwner())) {
auto then_branch = module.lookupSymbol<FuncOp>(if_op.then_branch());
auto else_branch = module.lookupSymbol<FuncOp>(if_op.else_branch());
assert(then_branch && else_branch);
auto type_from_then = GetStackElementType(
then_branch.getArgument(use.getOperandNumber() - 1), module);
if (type_from_then.hasValue()) return type_from_then;
auto type_from_else = GetStackElementType(
else_branch.getArgument(use.getOperandNumber() - 1), module);
if (type_from_else.hasValue()) return type_from_else;
} else if (auto pcall =
llvm::dyn_cast<TF::PartitionedCallOp>(use.getOwner())) {
if (!pcall.f().isa<FlatSymbolRefAttr>()) continue;
auto callee = module.lookupSymbol<FuncOp>(pcall.f().getRootReference());
assert(callee);
auto type_from_callee = GetStackElementType(
callee.getArgument(use.getOperandNumber()), module);
if (type_from_callee.hasValue()) return type_from_callee;
} else if (auto spcall = llvm::dyn_cast<TF::StatefulPartitionedCallOp>(
use.getOwner())) {
auto callee = module.lookupSymbol<FuncOp>(spcall.f());
assert(callee);
auto type_from_callee = GetStackElementType(
callee.getArgument(use.getOperandNumber()), module);
if (type_from_callee.hasValue()) return type_from_callee;
} else if (llvm::isa<TF::IdentityOp>(use.getOwner()) ||
llvm::isa<TF::IdentityNOp>(use.getOwner())) {
auto type_from_alias = GetStackElementType(
use.getOwner()->getResult(use.getOperandNumber()), module);
if (type_from_alias.hasValue()) return type_from_alias;
}
}
return llvm::None;
}
// Returns the aliasing argument number of a function return value if it simply
// forwards the argument. Otherwise, returns -1.
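// For example (illustrative), if the terminator is `return %arg2, %0`, then
// FindAliasedInput(func, 0) returns 2 and FindAliasedInput(func, 1) returns -1
// (assuming %0 is not a block argument).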
int64_t FindAliasedInput(FuncOp func, int64_t return_index) {
Value return_val = func.front().getTerminator()->getOperand(return_index);
auto maybe_arg = return_val.dyn_cast<BlockArgument>();
if (!maybe_arg) return -1;
return maybe_arg.getArgNumber();
}
// Changes the function signature that has stacks in the arguments. A stack
// argument will be turned into a variable type if arg_to_stack_type returns
// such a type, and a new argument will be added to the end of the argument
// list for the size variable.
//
// If stack_var_to_size_var is not nullptr, it will be used to store the
// mapping from the stack-variable argument to the size-variable argument.
//
// If handle_new_size_vars is provided, it will be invoked on the list of new
// size variables before finally changing the function type.
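//
// For example (illustrative), a signature
//   (tensor<f32>, tensor<!tf.resource>) -> tensor<f32>
// whose second argument is a decomposed stack becomes
//   (tensor<f32>, tensor<!tf.resource<tensor<10x2xf32>>>,
//    tensor<!tf.resource<tensor<1xi32>>>) -> tensor<f32>
// with the appended argument being the size variable.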
void ModifyFunctionSignature(
FuncOp func, llvm::SmallDenseMap<Value, Value>* stack_var_to_size_var,
llvm::function_ref<llvm::Optional<Type>(int64_t)> arg_to_stack_type,
llvm::function_ref<void(ArrayRef<BlockArgument>)> handle_new_size_vars =
nullptr) {
auto new_input_types = llvm::to_vector<8>(func.getType().getInputs());
auto size_var_type = GetSizeVarType(OpBuilder(func));
int64_t original_arg_count = new_input_types.size();
for (int64_t i = 0; i < original_arg_count; ++i) {
auto stack_type = arg_to_stack_type(i);
if (!stack_type.hasValue()) continue;
func.getArgument(i).setType(*stack_type);
new_input_types[i] = *stack_type;
auto size_arg = func.front().addArgument(size_var_type);
new_input_types.push_back(size_arg.getType());
if (stack_var_to_size_var) {
(*stack_var_to_size_var)[func.getArgument(i)] = size_arg;
}
}
if (handle_new_size_vars) {
handle_new_size_vars(func.getArguments().drop_front(original_arg_count));
}
func.setType(FunctionType::get(
new_input_types,
llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()),
func.getContext()));
}
// Contains cached information for decomposed callee functions for (stateful)
// partitioned call ops.
struct PartitionedCallStackOpsInfo {
bool signature_change;
FuncOp decomposed_callee;
llvm::SmallDenseMap<int64_t, int64_t> stack_var_arg_to_size_arg;
};
LogicalResult DecomposeStackOpsInternal(
Block*, ModuleOp, llvm::SmallDenseMap<Value, Value>*,
llvm::SmallDenseMap<FuncOp, PartitionedCallStackOpsInfo>*);
// Handles stack usage by a tf.While. It will convert the body and conditional
// function signatures, and perform stack ops decomposition on them.
LogicalResult HandleWhileOp(
TF::WhileOp while_op, ModuleOp module,
const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
llvm::SmallDenseMap<FuncOp, PartitionedCallStackOpsInfo>*
decomposed_partitioned_call_callees) {
auto body = module.lookupSymbol<FuncOp>(while_op.body());
llvm::SmallDenseMap<Value, Value> body_map;
auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
auto it = data_var_to_size_var.find(while_op.getOperand(index));
if (it == data_var_to_size_var.end()) return llvm::None;
return it->getFirst().getType();
};
auto add_size_vars_to_return = [&](ArrayRef<BlockArgument> new_args) {
if (new_args.empty()) return;
auto body_ret = body.front().getTerminator();
auto new_body_returns = llvm::to_vector<8>(body_ret->getOperands());
for (auto arg : new_args) new_body_returns.push_back(arg);
OpBuilder(body_ret).create<ReturnOp>(body_ret->getLoc(), new_body_returns);
body_ret->erase();
};
// Handle body.
ModifyFunctionSignature(body, &body_map, find_arg_stack_type,
add_size_vars_to_return);
const bool signature_change = !body_map.empty();
if (failed(DecomposeStackOpsInternal(&body.front(), module, &body_map,
decomposed_partitioned_call_callees))) {
return failure();
}
// Cond should not change stacks in the arguments, so use an empty map.
auto cond = module.lookupSymbol<FuncOp>(while_op.cond());
ModifyFunctionSignature(cond, nullptr, find_arg_stack_type);
llvm::SmallDenseMap<Value, Value> empty_map;
if (failed(DecomposeStackOpsInternal(&cond.front(), module, &empty_map,
decomposed_partitioned_call_callees))) {
return failure();
}
if (!signature_change) return success();
// Create the new while op.
auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
auto new_output_shapes =
llvm::to_vector<8>(while_op.output_shapes().getValue());
OpBuilder builder(while_op);
assert(while_op.getNumOperands() == while_op.getNumResults());
for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
auto it = data_var_to_size_var.find(while_op.getOperand(i));
if (it == data_var_to_size_var.end()) continue;
new_while_operands.push_back(it->getSecond());
if (!new_output_shapes.empty()) {
// Size is a scalar shape.
tensorflow::TensorShapeProto shape_proto;
new_output_shapes.push_back(builder.getStringAttr(
tensorflow::mangling_util::MangleShape(shape_proto)));
}
}
auto new_while =
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
new_while_operands, while_op.getAttrs());
new_while.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes));
for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
if (!getElementTypeOrSelf(while_op.getOperand(i).getType())
.isa<TF::ResourceType>()) {
continue;
}
int64_t aliased_input = FindAliasedInput(body, i);
if (aliased_input == i) {
// Replace aliased stack output uses with input.
while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
}
}
while_op.replaceAllUsesWith(
new_while.getResults().take_front(while_op.getNumResults()));
while_op.erase();
return success();
}
// Handles stack usage by a tf.If. It will convert the branch function
// signatures, and perform stack ops decomposition on them.
LogicalResult HandleIfOp(
TF::IfOp if_op, ModuleOp module,
const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
llvm::SmallDenseMap<FuncOp, PartitionedCallStackOpsInfo>*
decomposed_partitioned_call_callees) {
auto then_branch = module.lookupSymbol<FuncOp>(if_op.then_branch());
auto else_branch = module.lookupSymbol<FuncOp>(if_op.else_branch());
llvm::SmallDenseMap<Value, Value> then_map;
llvm::SmallDenseMap<Value, Value> else_map;
auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
auto it = data_var_to_size_var.find(if_op.getOperand(index + 1));
if (it == data_var_to_size_var.end()) return llvm::None;
return it->getFirst().getType();
};
ModifyFunctionSignature(then_branch, &then_map, find_arg_stack_type);
ModifyFunctionSignature(else_branch, &else_map, find_arg_stack_type);
const bool signature_change = !then_map.empty() || !else_map.empty();
if (failed(DecomposeStackOpsInternal(&then_branch.front(), module, &then_map,
decomposed_partitioned_call_callees)) ||
failed(DecomposeStackOpsInternal(&else_branch.front(), module, &else_map,
decomposed_partitioned_call_callees))) {
return failure();
}
if (!signature_change) return success();
auto new_if_operands = llvm::to_vector<8>(if_op.getOperands());
for (auto operand : if_op.getOperands()) {
auto it = data_var_to_size_var.find(operand);
if (it == data_var_to_size_var.end()) continue;
new_if_operands.push_back(it->getSecond());
}
auto new_if = OpBuilder(if_op).create<TF::IfOp>(
if_op.getLoc(), then_branch.getType().getResults(), new_if_operands,
if_op.getAttrs());
for (auto result : if_op.getResults()) {
if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
continue;
}
int64_t then_aliased_input =
FindAliasedInput(then_branch, result.getResultNumber());
int64_t else_aliased_input =
FindAliasedInput(else_branch, result.getResultNumber());
if (then_aliased_input >= 0 && then_aliased_input == else_aliased_input) {
// Replace aliased stack output uses with input.
result.replaceAllUsesWith(if_op.getOperand(then_aliased_input + 1));
}
}
if_op.replaceAllUsesWith(new_if);
if_op.erase();
return success();
}
// Handles stack usage by a tf.StatefulPartitionedCall or a tf.PartitionedCall.
// It will first check if the callee was previously handled, and try to reuse
// that result if so. Otherwise, it will clone and convert the callee function,
// and perform stack ops decomposition on it.
template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
CallOp call, FuncOp callee, ModuleOp module,
const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
llvm::SmallDenseMap<FuncOp, PartitionedCallStackOpsInfo>*
decomposed_partitioned_call_callees) {
auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
callee, PartitionedCallStackOpsInfo());
auto& info = emplace_res.first->getSecond();
// Recreate the call op with info.
auto recreate_caller = [&] {
auto new_operands = llvm::to_vector<8>(call.getOperands());
for (int64_t i = 0; i < call.getNumOperands(); ++i) {
auto arg_it = info.stack_var_arg_to_size_arg.find(i);
if (arg_it == info.stack_var_arg_to_size_arg.end()) continue;
auto it = data_var_to_size_var.find(call.getOperand(i));
if (it == data_var_to_size_var.end()) {
call.emitOpError("Unknown stack.");
return failure();
}
assert(arg_it->second == new_operands.size());
new_operands.push_back(it->getSecond());
}
OpBuilder builder(call);
auto new_call = builder.create<CallOp>(
call.getLoc(), info.decomposed_callee.getType().getResults(),
new_operands, call.getAttrs());
new_call.setAttr(
"f", builder.getSymbolRefAttr(
const_cast<FuncOp&>(info.decomposed_callee).getName()));
for (int64_t i = 0; i < call.getNumResults(); ++i) {
auto result = call.getResult(i);
if (!getElementTypeOrSelf(result.getType())
.template isa<TF::ResourceType>()) {
continue;
}
int64_t aliased_input = FindAliasedInput(info.decomposed_callee, i);
if (aliased_input >= 0) {
// Replace aliased stack output uses with input.
result.replaceAllUsesWith(call.getOperand(aliased_input));
}
}
call.replaceAllUsesWith(new_call);
call.erase();
return success();
};
if (!emplace_res.second) {
// This callee was handled before.
if (!info.signature_change) return success();
return recreate_caller();
}
llvm::SmallDenseMap<Value, Value> callee_map;
auto callee_clone = callee.clone();
auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
auto it = data_var_to_size_var.find(call.getOperand(index));
if (it == data_var_to_size_var.end()) return llvm::None;
return it->getFirst().getType();
};
ModifyFunctionSignature(callee_clone, &callee_map, find_arg_stack_type);
if (callee_map.empty()) {
// Signature is not modified. We do not need the clone.
info.signature_change = false;
callee_clone.erase();
} else {
info.signature_change = true;
info.decomposed_callee = callee_clone;
for (auto& entry : callee_map) {
info.stack_var_arg_to_size_arg
[entry.getFirst().cast<BlockArgument>().getArgNumber()] =
entry.getSecond().cast<BlockArgument>().getArgNumber();
}
// Add the clone with a new name.
auto name_base = llvm::join(
std::vector<std::string>{callee.getName().str(), "stack_decomposed"},
"_");
auto name = name_base;
{
int64_t counter = 0;
while (module.lookupSymbol(name)) {
name = llvm::formatv("{0}_{1}", name_base, counter++).str();
}
}
callee_clone.setName(name);
SymbolTable(module).insert(callee_clone);
callee = callee_clone;
}
if (failed(DecomposeStackOpsInternal(&callee.front(), module, &callee_map,
decomposed_partitioned_call_callees))) {
return failure();
}
if (info.signature_change) return recreate_caller();
return success();
}
LogicalResult HandleStackV2Op(
TF::StackV2Op stack, ModuleOp module,
llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
// Create a buffer variable and a size variable to replace the stack.
auto elem_type = GetStackElementType(stack.handle(), module);
if (!elem_type.hasValue()) {
return stack.emitOpError("cannot infer element shape of stack.");
}
auto size_op = stack.max_size().getDefiningOp();
if (!size_op || !llvm::isa<TF::ConstOp>(size_op)) {
return stack.emitOpError("max size of stack is not a constant.");
}
int64_t max_size =
(*llvm::cast<TF::ConstOp>(size_op).value().getValues<APInt>().begin())
.getSExtValue();
llvm::SmallVector<int64_t, 8> stack_shape;
stack_shape.push_back(max_size);
for (int64_t dim : elem_type->getShape()) stack_shape.push_back(dim);
auto stack_tensor_type =
RankedTensorType::get(stack_shape, elem_type->getElementType());
Value local_var;
Value local_size_var;
std::tie(local_var, local_size_var) =
CreateVariablesForStack(stack_tensor_type, stack);
stack.replaceAllUsesWith(local_var);
(*data_var_to_size_var)[local_var] = local_size_var;
stack.erase();
return success();
}
LogicalResult HandleStackPushV2Op(
TF::StackPushV2Op push,
llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
auto it = data_var_to_size_var->find(push.handle());
if (it == data_var_to_size_var->end()) {
return push.emitOpError("unknown stack.");
}
  // The push output simply forwards the input element.
push.replaceAllUsesWith(push.elem());
OpBuilder builder(push);
// Read the current buffer and size.
auto stack_val = ReadLocalVariable(push.handle(), builder, push.getLoc());
auto index = ReadLocalVariable(it->getSecond(), builder, push.getLoc());
auto stack_buffer_type = stack_val.getType().cast<RankedTensorType>();
auto slice_shape = llvm::to_vector<8>(stack_buffer_type.getShape());
slice_shape[0] = 1;
  // Calculate the updated buffer.
auto update_slice = builder.create<TF::ReshapeOp>(
push.getLoc(),
ArrayRef<Type>{RankedTensorType::get(slice_shape,
stack_buffer_type.getElementType())},
ArrayRef<Value>{push.elem(),
GetR1Const(slice_shape, builder, push.getLoc())},
ArrayRef<NamedAttribute>{});
stack_val =
builder
.create<TF::XlaDynamicUpdateSliceOp>(
push.getLoc(), ArrayRef<Type>{stack_val.getType()},
ArrayRef<Value>{stack_val, update_slice,
GetIndicesForStackElement(
index, stack_val, builder, push.getLoc())},
ArrayRef<NamedAttribute>{})
.output();
// Assign the new buffer and size.
WriteLocalVariable(push.handle(), stack_val, builder, push.getLoc());
index = builder.create<TF::AddV2Op>(
push.getLoc(), ArrayRef<Type>{index.getType()},
ArrayRef<Value>{index, GetR1Const({1}, builder, push.getLoc())},
ArrayRef<NamedAttribute>{});
WriteLocalVariable(it->getSecond(), index, builder, push.getLoc());
push.erase();
return success();
}
LogicalResult HandleStackPopV2Op(
TF::StackPopV2Op pop,
llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
auto it = data_var_to_size_var->find(pop.handle());
if (it == data_var_to_size_var->end()) {
return pop.emitOpError("unknown stack.");
}
OpBuilder builder(pop);
// Read the current buffer and size.
auto stack_val = ReadLocalVariable(pop.handle(), builder, pop.getLoc());
auto size = ReadLocalVariable(it->getSecond(), builder, pop.getLoc());
auto new_size = builder.create<TF::SubOp>(
pop.getLoc(), ArrayRef<Type>{size.getType()},
ArrayRef<Value>{size, GetR1Const({1}, builder, pop.getLoc())},
ArrayRef<NamedAttribute>{});
auto stack_val_type = stack_val.getType().cast<RankedTensorType>();
auto elem_type = RankedTensorType::get(stack_val_type.getShape().drop_front(),
stack_val_type.getElementType());
// Slice the buffer to get the element.
llvm::SmallVector<int64_t, 8> slice_size;
slice_size.push_back(1);
for (int64_t dim : elem_type.getShape()) slice_size.push_back(dim);
auto size_const = GetR1Const(slice_size, builder, pop.getLoc());
auto slice_type =
RankedTensorType::get(slice_size, stack_val_type.getElementType());
auto slice = builder.create<TF::SliceOp>(
pop.getLoc(), ArrayRef<Type>{slice_type},
ArrayRef<Value>{
stack_val,
GetIndicesForStackElement(new_size, stack_val, builder, pop.getLoc()),
size_const},
ArrayRef<NamedAttribute>{});
auto pop_val = builder.create<TF::ReshapeOp>(
pop.getLoc(), ArrayRef<Type>{elem_type},
ArrayRef<Value>{slice,
GetR1Const(elem_type.getShape(), builder, pop.getLoc())},
ArrayRef<NamedAttribute>{});
pop.replaceAllUsesWith(pop_val.output());
// Update the size.
WriteLocalVariable(it->getSecond(), new_size, builder, pop.getLoc());
pop.erase();
return success();
}
// Decomposes stack ops on a region and recursively decomposes called functions.
// data_var_to_size_var: a mapping from stacks' buffer local variables to size
// local variables.
// decomposed_partitioned_call_callees: cache for partitioned call ops' callee
// function handling.
LogicalResult DecomposeStackOpsInternal(
Block* block, ModuleOp module,
llvm::SmallDenseMap<Value, Value>* data_var_to_size_var,
llvm::SmallDenseMap<FuncOp, PartitionedCallStackOpsInfo>*
decomposed_partitioned_call_callees) {
for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
// Removes identity nodes in the block. The device computation does not
// need such nodes to carry information.
op.replaceAllUsesWith(op.getOperands());
op.erase();
} else if (auto stack = llvm::dyn_cast<TF::StackV2Op>(&op)) {
if (failed(HandleStackV2Op(stack, module, data_var_to_size_var))) {
return failure();
}
} else if (auto push = llvm::dyn_cast<TF::StackPushV2Op>(&op)) {
if (failed(HandleStackPushV2Op(push, data_var_to_size_var))) {
return failure();
}
} else if (auto pop = llvm::dyn_cast<TF::StackPopV2Op>(&op)) {
if (failed(HandleStackPopV2Op(pop, data_var_to_size_var))) {
return failure();
}
} else if (auto close = llvm::dyn_cast<TF::StackCloseV2Op>(&op)) {
data_var_to_size_var->erase(close.handle());
close.erase();
} else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
if (failed(HandleWhileOp(while_op, module, *data_var_to_size_var,
decomposed_partitioned_call_callees))) {
return failure();
}
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
if (failed(HandleIfOp(if_op, module, *data_var_to_size_var,
decomposed_partitioned_call_callees))) {
return failure();
}
} else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
if (!pcall.f().isa<FlatSymbolRefAttr>()) {
return pcall.emitOpError(
"Stack decomposition does not support call with nested "
"references.");
}
if (failed(HandlePartitionedCallOp(
pcall, module.lookupSymbol<FuncOp>(pcall.f().getRootReference()),
module, *data_var_to_size_var,
decomposed_partitioned_call_callees))) {
return failure();
}
} else if (auto spcall =
llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
if (failed(HandlePartitionedCallOp(
spcall, module.lookupSymbol<FuncOp>(spcall.f()), module,
*data_var_to_size_var, decomposed_partitioned_call_callees))) {
return failure();
}
}
}
return success();
}
LogicalResult DecomposeStackOps(Block* block, ModuleOp module) {
llvm::SmallDenseMap<Value, Value> data_var_to_size_var;
llvm::SmallDenseMap<FuncOp, PartitionedCallStackOpsInfo>
decomposed_partitioned_call_callees;
return DecomposeStackOpsInternal(block, module, &data_var_to_size_var,
&decomposed_partitioned_call_callees);
}
void StackOpsDecompositionPass::runOnModule() {
auto module = getModule();
auto main = module.lookupSymbol<FuncOp>("main");
if (!main) return;
if (failed(DecomposeStackOps(&main.front(), module))) {
signalPassFailure();
}
}
static PassRegistration<StackOpsDecompositionPass> pass(
"tf-stack-ops-decomposition",
"Decompose stack operations into local variable operations. Needs static "
"shapes.");
} // namespace
namespace TF {
std::unique_ptr<OpPassBase<ModuleOp>> CreateStackOpsDecompositionPass() {
return std::make_unique<StackOpsDecompositionPass>();
}
} // namespace TF
} // namespace mlir

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/STLExtras.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
@ -91,14 +92,14 @@ bool IsSupportedInputOp(Operation* op) {
}
// Builds a TPUGetLayoutOp with the given compile op and input index.
TF::TPUGetLayoutOp BuildGetLayout(Operation* compile, int64_t index,
OpBuilder* builder) {
builder->setInsertionPointAfter(compile);
TF::TPUGetLayoutOp BuildGetLayout(tf_device::LaunchOp compile_launch,
int64_t index, OpBuilder* builder) {
builder->setInsertionPointAfter(compile_launch);
return builder->create<TF::TPUGetLayoutOp>(
compile->getLoc(),
compile_launch.getLoc(),
llvm::ArrayRef<Type>{
RankedTensorType::get({-1}, builder->getIntegerType(64))},
llvm::ArrayRef<Value>{compile->getResult(1)},
llvm::ArrayRef<Value>{compile_launch.getResult(1)},
llvm::ArrayRef<NamedAttribute>{
builder->getNamedAttr("index", builder->getI64IntegerAttr(index)),
builder->getNamedAttr("is_output", builder->getBoolAttr(false))});
@ -109,11 +110,12 @@ TF::TPUGetLayoutOp BuildGetLayout(Operation* compile, int64_t index,
// ops after both get_layout and input, so we use the walk order to find which
// one comes later.
TF::TPUCopyWithLayoutOp BuildCopyWithLayout(
TF::TPUExecuteOp execute, Operation* compile, TF::TPUGetLayoutOp get_layout,
Value input, const llvm::SmallDenseMap<Operation*, int64_t>& walk_order,
TF::TPUExecuteOp execute, tf_device::LaunchOp compile_launch,
TF::TPUGetLayoutOp get_layout, Value input,
const llvm::SmallDenseMap<Operation*, int64_t>& walk_order,
OpBuilder* builder) {
auto input_op = input.getDefiningOp();
int64_t compile_walk_order = walk_order.find(compile)->getSecond();
int64_t compile_walk_order = walk_order.find(compile_launch)->getSecond();
int64_t input_walk_order = walk_order.find(input_op)->getSecond();
if (compile_walk_order > input_walk_order) {
builder->setInsertionPointAfter(get_layout);
@ -128,22 +130,22 @@ TF::TPUCopyWithLayoutOp BuildCopyWithLayout(
// Performs transformation for a non-replicated input.
void HandleInput(Value input, int64_t index, TF::TPUExecuteOp execute,
Operation* compile,
tf_device::LaunchOp compile_launch,
const llvm::SmallDenseMap<Operation*, int64_t>& walk_order) {
OpBuilder builder(compile->getContext());
auto get_layout = BuildGetLayout(compile, index, &builder);
auto copy_with_layout = BuildCopyWithLayout(execute, compile, get_layout,
input, walk_order, &builder);
if (auto device = execute.getAttrOfType<StringAttr>(kDeviceAttr)) {
copy_with_layout.setAttr(kDeviceAttr, device);
}
OpBuilder builder(compile_launch.getContext());
auto get_layout = BuildGetLayout(compile_launch, index, &builder);
auto copy_with_layout = BuildCopyWithLayout(
execute, compile_launch, get_layout, input, walk_order, &builder);
copy_with_layout.setAttr(
kDeviceAttr,
llvm::cast<tf_device::LaunchOp>(execute.getParentOp()).deviceAttr());
execute.setOperand(index, copy_with_layout);
}
// Performs transformation for replicated inputs. Returns true if this is a
// supported case (thus transform happened).
bool HandleReplicatedInputs(
int64_t index, TF::TPUExecuteOp execute, Operation* compile,
int64_t index, TF::TPUExecuteOp execute, tf_device::LaunchOp compile_launch,
int64_t replicate_arg_index, tf_device::ReplicateOp replicate,
const llvm::SmallDenseMap<Operation*, int64_t>& walk_order) {
// We need to know the devices to copy to.
@ -157,10 +159,11 @@ bool HandleReplicatedInputs(
if (!input_op || !IsSupportedInputOp(input_op)) return false;
}
OpBuilder builder(execute.getContext());
auto get_layout = BuildGetLayout(compile, index, &builder);
auto get_layout = BuildGetLayout(compile_launch, index, &builder);
for (auto entry : llvm::enumerate(inputs)) {
auto copy_with_layout = BuildCopyWithLayout(
execute, compile, get_layout, entry.value(), walk_order, &builder);
auto copy_with_layout =
BuildCopyWithLayout(execute, compile_launch, get_layout, entry.value(),
walk_order, &builder);
// As model parallelism is not supported yet, assume that all ops are
// placed at logical core 0.
@ -179,7 +182,7 @@ bool HandleReplicatedInputs(
// Performs transformation on a pair of execute and compile ops. The compile
// should not have other uses.
void HandleExecute(TF::TPUExecuteOp execute, Operation* compile,
void HandleExecute(TF::TPUExecuteOp execute, tf_device::LaunchOp compile_launch,
const llvm::SmallDenseMap<Operation*, int64_t>& walk_order) {
auto maybe_replicate = execute.getParentOfType<tf_device::ReplicateOp>();
llvm::SmallVector<int64_t, 8> unrestricted_input_indices;
@ -188,7 +191,7 @@ void HandleExecute(TF::TPUExecuteOp execute, Operation* compile,
// For a block argument, consider transforms only when it is a replicated
// input (defining ops will be outside the replicate node).
if (maybe_replicate != block_arg.getParentRegion()->getParentOp() ||
!HandleReplicatedInputs(input.index(), execute, compile,
!HandleReplicatedInputs(input.index(), execute, compile_launch,
block_arg.getArgNumber(), maybe_replicate,
walk_order)) {
continue;
@ -204,26 +207,28 @@ void HandleExecute(TF::TPUExecuteOp execute, Operation* compile,
continue;
}
if (!IsSupportedInputOp(input_op)) continue;
HandleInput(input.value(), input.index(), execute, compile, walk_order);
HandleInput(input.value(), input.index(), execute, compile_launch,
walk_order);
}
unrestricted_input_indices.push_back(input.index());
}
if (unrestricted_input_indices.empty()) return;
// Update the compilation metadata if we changed anything.
auto metadata_attr = compile->getAttrOfType<StringAttr>("metadata");
Operation& compile = compile_launch.GetBody().front();
auto metadata_attr = compile.getAttrOfType<StringAttr>("metadata");
assert(metadata_attr && "Missing compilation metadata");
tensorflow::tpu::TPUCompileMetadataProto metadata;
metadata.ParseFromString(std::string(metadata_attr.getValue()));
for (int64_t input_index : unrestricted_input_indices) {
metadata.mutable_args(input_index)->set_unrestricted_layout(true);
}
compile->setAttr("metadata", OpBuilder(compile).getStringAttr(
metadata.SerializeAsString()));
compile.setAttr("metadata", StringAttr::get(metadata.SerializeAsString(),
compile.getContext()));
}
void TPUDynamicLayoutPass::runOnFunction() {
llvm::SmallVector<std::pair<TF::TPUExecuteOp, Operation*>, 4>
llvm::SmallVector<std::pair<TF::TPUExecuteOp, tf_device::LaunchOp>, 4>
executes_and_compiles;
llvm::SmallDenseMap<Operation*, int64_t> walk_order;
int64_t next_walk_order = 0;
@ -232,12 +237,17 @@ void TPUDynamicLayoutPass::runOnFunction() {
// Detect tf._TPUCompileMlir -> tf.TPUExecute
auto execute = llvm::dyn_cast<TF::TPUExecuteOp>(op);
if (!execute) return;
auto execute_launch =
llvm::dyn_cast_or_null<tf_device::LaunchOp>(execute.getParentOp());
if (!execute_launch || !execute_launch.WrapsSingleOp()) return;
auto compile = execute.key().getDefiningOp();
if (!compile || compile->getName().getStringRef() != "tf._TPUCompileMlir" ||
!compile->getResult(1).hasOneUse()) {
if (!compile || !compile->getResult(1).hasOneUse()) return;
auto compile_launch = llvm::dyn_cast<tf_device::LaunchOp>(compile);
if (!compile_launch || !compile_launch.WrapsSingleOp() ||
compile_launch.GetBody().front().getName().getStringRef() !=
"tf._TPUCompileMlir")
return;
}
executes_and_compiles.emplace_back(execute, compile);
executes_and_compiles.emplace_back(execute, compile_launch);
});
for (auto execute_and_compile : executes_and_compiles) {
HandleExecute(execute_and_compile.first, execute_and_compile.second,

View File

@ -121,13 +121,13 @@ bool OpAccessesResource(Operation* op) {
// e.g., guaranteed by replication.
// - `check_same_region` specifies whether the reads/assigns need to be in the
// same region as `execute`. This is needed if `execute` is inside ReplicateOp.
VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
bool check_device,
bool check_same_region) {
VariableAccessesForTPUExecute BuildVariableAccessInfo(
tf_device::LaunchOp execute_launch, bool check_device,
bool check_same_region) {
VariableAccessesForTPUExecute infos;
auto device_attr = execute->getAttr(kDeviceAttr);
Attribute device_attr = execute_launch.deviceAttr();
if (check_device && !device_attr) return infos;
auto func = execute->getParentOfType<mlir::FuncOp>();
auto func = execute_launch.getParentOfType<mlir::FuncOp>();
// Track the first read op found, which is used later to check if there are
// assign ops between it and the TPUExecute op. We will exclude reads before
@ -135,21 +135,23 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
  // consider resource accesses in other islands since their ordering is enforced
// by inter-island dependencies.
Operation* first_read = nullptr;
Operation& execute = execute_launch.GetBody().front();
// Find inputs that are variable reads.
for (auto operand : llvm::enumerate(execute->getOpOperands())) {
for (auto operand : llvm::enumerate(execute.getOpOperands())) {
infos.new_operand_values.push_back(operand.value().get());
if (!operand.value().get().getDefiningOp()) continue;
auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(
operand.value().get().getDefiningOp());
if (!read_op) continue;
if (check_same_region &&
read_op.getParentRegion() != execute->getParentRegion()) {
read_op.getParentRegion() != execute_launch.getParentRegion()) {
continue;
}
auto resource = read_op.resource();
if (check_device) {
if (auto resource_op = resource.getDefiningOp()) {
// TODO(lyandy): Wrap resource ops in tf_device.launch.
if (auto* resource_op = resource.getDefiningOp()) {
auto resource_attr = resource_op->getAttr(kDeviceAttr);
// Check device matching for the node defining the resource.
if (!resource_attr || resource_attr != device_attr) continue;
@ -169,7 +171,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
if (!emplace_res.second) {
LLVM_DEBUG(llvm::dbgs()
<< "Skipping execute that has multiple reads of a variable: "
<< *execute << "\n");
<< execute << "\n");
infos.per_resource_info.shrink_and_clear();
return infos;
}
@ -191,8 +193,9 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
// work fine for the reads/assigns created by resource lifting, since they are
// placed close to the TPUExecute.
Operation* last_may_modify_resource_access_before_execute = nullptr;
for (Operation& op : llvm::reverse(llvm::make_range(
std::next(first_read->getIterator()), execute->getIterator()))) {
for (Operation& op : llvm::reverse(
llvm::make_range(std::next(first_read->getIterator()),
execute_launch.getOperation()->getIterator()))) {
if (llvm::dyn_cast<TF::ReadVariableOp>(&op)) continue;
if (!OpAccessesResource(&op)) continue;
last_may_modify_resource_access_before_execute = &op;
@ -209,7 +212,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
auto info_it = infos.per_resource_info.find(read.resource());
if (info_it == infos.per_resource_info.end()) continue;
int input_index = info_it->getSecond().execute_input_index;
infos.new_operand_values[input_index] = execute->getOperand(input_index);
infos.new_operand_values[input_index] = execute.getOperand(input_index);
infos.per_resource_info.erase(info_it);
}
infos.resources_read.erase(
@ -227,9 +230,12 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
// Find outputs that are variable assigns.
Operation* last_assign = nullptr;
llvm::SmallPtrSet<Operation*, 8> all_assigns;
llvm::SmallVector<bool, 8> output_fused(execute->getNumResults(), false);
for (int i = 0; i < execute->getNumResults(); ++i) {
auto result = execute->getResult(i);
llvm::SmallVector<bool, 8> output_fused(execute_launch.getNumResults(),
false);
for (int i = 0; i < execute_launch.getNumResults(); ++i) {
// TODO(lyandy): Handle updates to resource writes by remapping to parent
// launch result and checking if launch result is an AssignVariableOp.
auto result = execute_launch.getResult(i);
if (!result.hasOneUse()) continue;
auto assign_op = llvm::dyn_cast<TF::AssignVariableOp>(*result.user_begin());
if (!assign_op) continue;
@ -240,7 +246,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
if (info.assign) {
LLVM_DEBUG(llvm::dbgs()
<< "Skipping execute that has multiple assigns of a variable: "
<< *execute << "\n");
<< execute << "\n");
infos.per_resource_info.shrink_and_clear();
return infos;
}
@ -256,8 +262,9 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
// Check if there are other resource accesses after execute.
Operation* first_unknown_resource_access_after_execute = nullptr;
if (last_assign) {
for (auto& op : llvm::make_range(std::next(execute->getIterator()),
last_assign->getIterator())) {
for (auto& op : llvm::make_range(
std::next(execute_launch.getOperation()->getIterator()),
last_assign->getIterator())) {
if (all_assigns.count(&op) > 0) continue;
if (!OpAccessesResource(&op)) continue;
first_unknown_resource_access_after_execute = &op;
@ -282,8 +289,8 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
// Populate infos.old_to_new_output_mapping.
int new_output_index = 0;
infos.old_to_new_output_mapping.resize(execute->getNumResults());
for (int i = 0; i < execute->getNumResults(); ++i) {
infos.old_to_new_output_mapping.resize(execute_launch.getNumResults());
for (int i = 0; i < execute_launch.getNumResults(); ++i) {
if (output_fused[i]) {
infos.old_to_new_output_mapping[i] = -1;
} else {
@ -295,19 +302,20 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
}
// Merges the variable accesses into one TPUExecute op.
void MergeForOneTPUExecute(Operation* execute, bool check_device,
bool check_same_region, OpBuilder* builder) {
void MergeForOneTPUExecute(tf_device::LaunchOp execute_launch,
bool check_device, bool check_same_region,
OpBuilder* builder) {
auto infos =
BuildVariableAccessInfo(execute, check_device, check_same_region);
BuildVariableAccessInfo(execute_launch, check_device, check_same_region);
if (infos.per_resource_info.empty()) {
return;
}
// Start creating the new TPUExecuteAndUpdateVariables op.
builder->setInsertionPoint(execute);
builder->setInsertionPoint(execute_launch);
// Output types. Skip the original outputs for fused assigns.
llvm::SmallVector<Type, 8> new_output_types;
int old_output_index = 0;
for (const auto& type : execute->getResultTypes()) {
for (const auto& type : execute_launch.getResultTypes()) {
if (infos.old_to_new_output_mapping[old_output_index] >= 0) {
new_output_types.push_back(type);
}
@ -322,7 +330,7 @@ void MergeForOneTPUExecute(Operation* execute, bool check_device,
device_var_updates_indices.push_back(info.execute_output_index);
}
auto merged_execute = builder->create<TF::TPUExecuteAndUpdateVariablesOp>(
execute->getLoc(), new_output_types, infos.new_operand_values,
execute_launch.getLoc(), new_output_types, infos.new_operand_values,
llvm::ArrayRef<NamedAttribute>{
builder->getNamedAttr(
"device_var_reads_indices",
@ -331,15 +339,24 @@ void MergeForOneTPUExecute(Operation* execute, bool check_device,
"device_var_updates_indices",
builder->getI64ArrayAttr(device_var_updates_indices))});
if (auto device = execute->getAttr(kDeviceAttr)) {
merged_execute.setAttr(kDeviceAttr, device);
}
// Wrap in launch for device assignment.
auto merged_execute_launch = builder->create<tf_device::LaunchOp>(
merged_execute.getLoc(), execute_launch.deviceAttr(),
merged_execute.getResultTypes());
merged_execute_launch.body().push_back(new Block);
builder->setInsertionPointToEnd(&merged_execute_launch.GetBody());
builder->create<tf_device::ReturnOp>(merged_execute.getLoc(),
merged_execute.getResults());
merged_execute.getOperation()->moveBefore(
merged_execute_launch.GetBody().getTerminator());
// Replace the uses.
for (int i = 0; i < infos.old_to_new_output_mapping.size(); ++i) {
if (infos.old_to_new_output_mapping[i] < 0) continue;
execute->getResult(i).replaceAllUsesWith(
merged_execute.getResult(infos.old_to_new_output_mapping[i]));
execute_launch.getResult(i).replaceAllUsesWith(
merged_execute_launch.getResult(infos.old_to_new_output_mapping[i]));
}
// Remove the assign ops.
for (const auto& entry : infos.per_resource_info) {
@ -347,7 +364,7 @@ void MergeForOneTPUExecute(Operation* execute, bool check_device,
if (info.assign) info.assign->erase();
}
// Remove the original TPUExecute op.
execute->erase();
execute_launch.erase();
// Remove the read ops if they have no more uses.
for (const auto& entry : infos.per_resource_info) {
const auto& info = entry.getSecond();
@ -358,19 +375,22 @@ void MergeForOneTPUExecute(Operation* execute, bool check_device,
void TPUMergeVariablesWithExecutePass::runOnFunction() {
// Find all the executes first, since we will mutate the nodes around each
// execute.
llvm::SmallVector<Operation*, 8> executes;
getFunction().walk([&](TF::TPUExecuteOp op) { executes.push_back(op); });
llvm::SmallVector<tf_device::LaunchOp, 8> execute_launches;
getFunction().walk([&](tf_device::LaunchOp op) {
if (op.WrapsSingleOp() && llvm::isa<TF::TPUExecuteOp>(op.GetBody().front()))
execute_launches.push_back(op);
});
for (auto execute : executes) {
for (auto execute_launch : execute_launches) {
OpBuilder builder(&getContext());
const bool parent_is_replicate =
llvm::isa<tf_device::ReplicateOp>(execute->getParentOp());
llvm::isa<tf_device::ReplicateOp>(execute_launch.getParentOp());
// If this is inside a tf_device::ReplicateOp, the variables are guaranteed
// to be on the same device as the TPUExecute op. Skip device checking in
// that case, but we need to check that we are only merging reads/assigns
// that are also in this replicated region.
MergeForOneTPUExecute(execute, !parent_is_replicate, parent_is_replicate,
&builder);
MergeForOneTPUExecute(execute_launch, !parent_is_replicate,
parent_is_replicate, &builder);
}
}

View File

@ -255,6 +255,26 @@ LogicalResult SetMetadataProtoFromLaunchFuncOp(
return success();
}
// Wraps single op in `tf_device.launch` for explicit device assignment.
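// For example (illustrative), a single op %0 = "tf.OpA"() placed on device "d"
// becomes
//   %0 = "tf_device.launch"() ( {
//     %1 = "tf.OpA"()
//     tf_device.return %1
//   }) {device = "d"}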
tf_device::LaunchOp WrapOpInLaunch(OpBuilder* builder, Location loc,
Operation* op, llvm::StringRef device) {
OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();
auto launch = builder->create<tf_device::LaunchOp>(
loc, builder->getStringAttr(device), op->getResultTypes());
launch.body().push_back(new Block);
builder->setInsertionPointToEnd(&launch.GetBody());
builder->create<tf_device::ReturnOp>(loc, op->getResults());
// Move op inside launch.
op->moveBefore(launch.GetBody().getTerminator());
builder->restoreInsertionPoint(insert_point);
return launch;
}
// Create a `tf._TPUCompileMlir` that contains a MLIR module that is
// functionally equivalent to the function referenced by launch_func.
Operation* BuildCompileOp(
@ -319,9 +339,6 @@ Operation* BuildCompileOp(
compile_op_state.addAttribute("mlir_module",
builder->getStringAttr(txt_module));
compile_op_state.addAttribute(kDeviceAttr,
builder->getStringAttr(compilation_device));
// Result #0 is a string indicating whether compilation is successful or not.
compile_op_state.addTypes(
RankedTensorType::get({}, builder->getType<TF::StringType>()));
@ -330,7 +347,10 @@ Operation* BuildCompileOp(
compile_op_state.addTypes(
RankedTensorType::get({}, builder->getType<TF::StringType>()));
return builder->createOperation(compile_op_state);
Operation* compile_op = builder->createOperation(compile_op_state);
return WrapOpInLaunch(builder, compile_op->getLoc(), compile_op,
compilation_device);
}
// Creates a `tf.TPUExecute` op that executes TPU program generated by
@ -377,6 +397,7 @@ Operation* BuildParallelExecuteOp(int num_logical_cores, Operation* compile_op,
// For each logical core, create a region with TPUExecute op.
for (int core_id = 0; core_id < num_logical_cores; ++core_id) {
auto& region = parallel_execute_op.GetRegionBlockWithIndex(core_id);
builder->setInsertionPointToEnd(&region);
// Create Execute op.
//
@ -389,20 +410,9 @@ Operation* BuildParallelExecuteOp(int num_logical_cores, Operation* compile_op,
//
// TODO(b/149102679): Add device attribute to launch op once device
// topology for multiple logical cores can be correctly parsed.
builder->setInsertionPointToStart(&region);
auto region_launch_op = builder->create<tf_device::LaunchOp>(
region.getParent()->getLoc(), builder->getStringAttr(""),
launch_func.results().getTypes());
region_launch_op.body().push_back(new Block);
auto region_launch_op = WrapOpInLaunch(
builder, region.getParent()->getLoc(), execute, /*device=*/"");
builder->setInsertionPointToEnd(&region_launch_op.GetBody());
builder->create<tf_device::ReturnOp>(region_launch_op.getLoc(),
execute->getResults());
// Move execute inside the launch op.
execute->moveBefore(&region_launch_op.GetBody().back());
builder->setInsertionPointToEnd(&region);
builder->create<tf_device::ReturnOp>(region.getParent()->getLoc(),
region_launch_op.getResults());
}
@ -423,13 +433,14 @@ void RemapOutputsOfParallelExecute(tf_device::LaunchFuncOp launch_func,
std::get<0>(outputs).replaceAllUsesWith(std::get<1>(outputs));
}
void AssignDevicesToReplicatedExecute(
tf_device::LaunchOp AssignDevicesToReplicatedExecute(
llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
tf_device::ReplicateOp replicate, Operation* execute_op,
OpBuilder* builder) {
// If computation is replicated, execution devices are assigned to the
// replicate. Otherwise there is only one execution device and the device is
// assigned to the execute op.
std::string device;
if (replicate) {
    // Model parallelism is not supported for now. Therefore, assign all ops
// in replicate op with virtual device alias specifying that ops will be
@ -439,24 +450,27 @@ void AssignDevicesToReplicatedExecute(
for (const auto& replica_execution_devices : execution_devices)
replicate_execution_devices.push_back(replica_execution_devices.front());
device = tensorflow::GetDeviceAliasForLogicalCore(0);
auto device_attr = builder->getNamedAttr(
tensorflow::GetDeviceAliasForLogicalCore(0),
builder->getStrArrayAttr(replicate_execution_devices));
device, builder->getStrArrayAttr(replicate_execution_devices));
replicate.setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attr));
} else {
execute_op->setAttr(
kDeviceAttr, builder->getStringAttr(execution_devices.front().front()));
device = execution_devices.front().front();
}
return WrapOpInLaunch(builder, execute_op->getLoc(), execute_op, device);
}
// Creates a `tf.TPUCompileSucceededAssert` operation that parses compilation
// status of `compile_op` to check whether compilation is successful.
void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
llvm::StringRef compilation_device,
OpBuilder* builder) {
OperationState assert_op_state(compile_op->getLoc(),
"tf.TPUCompileSucceededAssert");
assert_op_state.addOperands(compile_op->getResult(0));
builder->createOperation(assert_op_state);
Operation* assert_op = builder->createOperation(assert_op_state);
WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device);
}
// Rewrites a `tf_device.launch_func` operation into a set of TPU Runtime
@ -547,7 +561,7 @@ LogicalResult Rewrite(
<< "error in fetching TPU compilation/execution devices: "
<< status_or_tpu_device_assignment.status().error_message();
// Create compile op;
// Create compile op.
auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
builder->setInsertionPoint(launch_func);
Operation* compile_op = BuildCompileOp(
@ -564,24 +578,25 @@ LogicalResult Rewrite(
for (auto compile_result_op : block->getOps<TF::TPUCompilationResultOp>())
compile_result_op.output().replaceAllUsesWith(compile_op->getResult(0));
BuildTPUCompileSucceededAssertOp(compile_op, builder);
BuildTPUCompileSucceededAssertOp(
compile_op, tpu_device_assignment.compilation_device, builder);
Operation* execute_op;
if (num_cores_per_replica > 1) {
// For model parallelism, tf_device.parallel_execute is used to express
// concurrent device execution across multiple logical devices.
execute_op = BuildParallelExecuteOp(num_cores_per_replica, compile_op,
launch_func, builder);
Operation* execute_op = BuildParallelExecuteOp(
num_cores_per_replica, compile_op, launch_func, builder);
RemapOutputsOfParallelExecute(launch_func, execute_op);
// TODO(hongjunchoi): Correctly parse TPU topology and assign logical device
// attributes to launch_op's within parallel_execute op.
} else {
execute_op = BuildExecuteOp(compile_op, launch_func, builder);
AssignDevicesToReplicatedExecute(tpu_device_assignment.execution_devices,
replicate, execute_op, builder);
launch_func.replaceAllUsesWith(execute_op);
Operation* execute_op = BuildExecuteOp(compile_op, launch_func, builder);
tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute(
tpu_device_assignment.execution_devices, replicate, execute_op,
builder);
launch_func.replaceAllUsesWith(launch_op);
}
launch_func.erase();

View File

@ -0,0 +1,201 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/UseDefLists.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace mlir {
namespace TFTPU {
namespace {
constexpr char kXlaShardingAttr[] = "_XlaSharding";
constexpr char kInputShardingAttr[] = "input_sharding_configuration";
constexpr char kOutputShardingAttr[] = "output_sharding_configuration";
struct TPUShardingIdentificationPass
: public ModulePass<TPUShardingIdentificationPass> {
void runOnModule() override;
};
// An XlaSharding op may be a direct user of an input, but the input may also be
// followed by an Identity op or, when the bfloat16 type is used, by a Cast op
// added right after the input. As such, walk the users of the operation to
// reach the connected XlaSharding op.
//
// TODO(hongjunchoi): Consider explicitly checking op patterns to detect
// sharded inputs.
void GetAdjacentToXlaShardingOp(
Operation* op, llvm::Optional<TF::XlaShardingOp>* sharding_op) {
// TODO(hongjunchoi): Detect the case when sharding configuration is
// ambiguous for a single input (i.e. multiple different XlaSharding ops
// with different configuration policies are connected).
if (sharding_op->hasValue()) return;
if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(op)) {
sharding_op->emplace(sharding);
return;
}
if (llvm::isa<TF::IdentityOp>(op) || llvm::isa<TF::CastOp>(op)) {
for (auto user : op->getUsers())
GetAdjacentToXlaShardingOp(user, sharding_op);
}
}
llvm::Optional<StringRef> ParseShardingAttribute(Operation* operation) {
const auto& sharding_attr =
operation->getAttrOfType<StringAttr>(kXlaShardingAttr);
if (!sharding_attr) return llvm::Optional<StringRef>();
return sharding_attr.getValue();
}
// Parses the XlaSharding op connected to an input arg. If the input to the
// tf_device.LaunchFunc op is of resource type, the XlaSharding op will be
// connected to the following ReadVariable op.
//
// TODO(hongjunchoi): Add logic to parse XlaSharding op inside a
// Call op or if/while op.
llvm::Optional<StringRef> ParseInputSharding(const FuncOp func,
const int arg_index,
const Value& arg) {
llvm::Optional<TF::XlaShardingOp> parsed_sharding_op;
for (auto user : arg.getUsers()) {
if (parsed_sharding_op) continue;
GetAdjacentToXlaShardingOp(user, &parsed_sharding_op);
if (parsed_sharding_op) continue;
if (llvm::isa<TF::ReadVariableOp>(user))
for (auto read_variable_user : user->getUsers())
GetAdjacentToXlaShardingOp(read_variable_user, &parsed_sharding_op);
}
if (!parsed_sharding_op) return llvm::Optional<StringRef>();
return ParseShardingAttribute(parsed_sharding_op->getOperation());
}
// If the operand of a return value of the tf_device.LaunchFunc op comes directly
// from an XlaSharding op, returns the provided sharding configuration.
llvm::Optional<StringRef> ParseReturnValueSharding(FuncOp func,
const int output_index,
const OpOperand& operand) {
if (auto sharding_op =
llvm::dyn_cast<TF::XlaShardingOp>(operand.get().getDefiningOp())) {
return ParseShardingAttribute(sharding_op.getOperation());
}
return llvm::Optional<StringRef>();
}
// Add parsed sharding configuration to tf_device.LaunchFunc op attribute.
void SetShardingConfigurationAsAttribute(
tf_device::LaunchFuncOp launch_func, const std::string& attr_name,
const llvm::SmallVector<std::string, 8>& sharding_config) {
auto input_sharding_array_ref = llvm::SmallVector<llvm::StringRef, 8>(
sharding_config.begin(), sharding_config.end());
launch_func.setAttr(attr_name,
mlir::Builder(launch_func.getContext())
.getStrArrayAttr(input_sharding_array_ref));
}
// If XlaSharding op is connected to input/output of the tf_device.LaunchFuncOp,
// then add attributes to the op specifying the sharding configurations.
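//
// For example (illustrative), a launch_func whose first input feeds an
// XlaSharding op and whose remaining inputs do not ends up with attributes like
//   input_sharding_configuration = [<parsed _XlaSharding>, <maximal, device 0>]
//   output_sharding_configuration = [...]
// where each entry is a serialized xla::OpSharding proto.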
void IdentifyXlaShardingForTPUComputation(tf_device::LaunchFuncOp launch_func) {
// Look up function definition from module.
FuncOp func = launch_func.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
launch_func.func());
Block& func_entry_block = func.getBody().getBlocks().front();
  // By default, inputs have maximal sharding and are assigned to logical core 0
  // if no sharding is defined.
llvm::SmallVector<std::string, 8> sharding_for_args(
func_entry_block.getNumArguments(),
xla::sharding_builder::AssignDevice(0).SerializeAsString());
// Iterate through input arguments to the entry block of tf_device.LaunchFunc.
// For each input, look for a following XlaSharding op. XlaSharding ops can
// 1) Directly follow the input argument if the argument has a non-resource
//    type.
// 2) Follow a ReadVariableOp if the input type is a resource type.
// 3) Follow an IdentityOp or CastOp after the above cases (1), (2).
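// For illustration only (a hypothetical chain, not taken from this pass's
// tests): an input of resource type followed by
//   tf.ReadVariableOp -> tf.Identity -> tf.XlaSharding
// would have the XlaSharding op's configuration recorded for that argument.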
for (auto& arg_index_and_value :
llvm::enumerate(func_entry_block.getArguments())) {
const int arg_index = arg_index_and_value.index();
auto& arg = arg_index_and_value.value();
auto input_arg_sharding = ParseInputSharding(func, arg_index, arg);
if (!input_arg_sharding.hasValue()) continue;
sharding_for_args[arg_index] = input_arg_sharding->str();
}
SetShardingConfigurationAsAttribute(launch_func, kInputShardingAttr,
sharding_for_args);
// By default, return values from logical core 0 are used if no sharding
// configuration is defined.
llvm::SmallVector<std::string, 8> sharding_for_return_values(
func_entry_block.getTerminator()->getNumOperands(),
xla::sharding_builder::AssignDevice(0).SerializeAsString());
// Iterate through operands of the terminator; if the preceding op is an
// XlaShardingOp, add the provided sharding configuration to the launch_func
// attribute.
for (auto& return_value_and_index :
llvm::enumerate(func_entry_block.getTerminator()->getOpOperands())) {
int return_value_index = return_value_and_index.index();
const auto& return_value = return_value_and_index.value();
auto return_val_sharding =
ParseReturnValueSharding(func, return_value_index, return_value);
if (return_val_sharding)
sharding_for_return_values[return_value_index] =
return_val_sharding->str();
}
SetShardingConfigurationAsAttribute(launch_func, kOutputShardingAttr,
sharding_for_return_values);
}
void TPUShardingIdentificationPass::runOnModule() {
getModule().walk([&](tf_device::LaunchFuncOp launch_func) {
IdentifyXlaShardingForTPUComputation(launch_func);
});
}
} // anonymous namespace
std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUShardingIdentificationPass() {
return std::make_unique<TPUShardingIdentificationPass>();
}
static PassRegistration<TPUShardingIdentificationPass> pass(
"tf-tpu-sharding-identification",
"Identifies and handles inputs/outputs of TPU computation that is "
"sharded across logical cores.");
} // namespace TFTPU
} // namespace mlir

View File

@ -38,6 +38,7 @@ limitations under the License.
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
#include "mlir/Support/STLExtras.h" // TF:llvm-project
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
@ -143,8 +144,8 @@ Value SkipIdentity(Value v, bool allow_other_use,
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4>
AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
TF::WhileOp while_op, tf_device::ReplicateOp replicate,
TF::TPUExecuteAndUpdateVariablesOp execute, Operation* compile, FuncOp body,
FuncOp cond) {
TF::TPUExecuteAndUpdateVariablesOp execute,
tf_device::LaunchOp compile_launch, FuncOp body, FuncOp cond) {
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4> mapping;
auto mirrored_variable_indices_attr =
replicate.getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
@ -168,7 +169,8 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
if (replicate_arg_to_execute_arg.empty()) return mapping;
// Parse the original compile metadata.
auto metadata_str = compile->getAttrOfType<StringAttr>("metadata");
Operation& compile = compile_launch.GetBody().front();
auto metadata_str = compile.getAttrOfType<StringAttr>("metadata");
assert(metadata_str && "Missing compilation metadata");
tensorflow::tpu::TPUCompileMetadataProto metadata;
metadata.ParseFromString(std::string(metadata_str.getValue()));
@ -242,8 +244,8 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
}
}
// Update the metadata of the compile op.
compile->setAttr("metadata", OpBuilder(compile).getStringAttr(
metadata.SerializeAsString()));
compile.setAttr("metadata", StringAttr::get(metadata.SerializeAsString(),
compile.getContext()));
return mapping;
}
@ -393,26 +395,56 @@ llvm::SmallVector<TF::VarHandleOp, 4> CreateStateVars(
return state_vars;
}
// Wraps single op in `tf_device.launch` for explicit device assignment.
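// For example (an illustrative sketch, not verbatim from the tests), an op
//   %0 = "tf.SomeOp"() : () -> tensor<i32>
// is rewritten into
//   %0 = "tf_device.launch"() ( {
//     %1 = "tf.SomeOp"() : () -> tensor<i32>
//     tf_device.return %1 : tensor<i32>
//   }) {device = "<device>"} : () -> tensor<i32>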
void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
llvm::StringRef device) {
OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();
auto launch = builder->create<tf_device::LaunchOp>(
loc, builder->getStringAttr(device), op->getResultTypes());
launch.body().push_back(new Block);
builder->setInsertionPointToEnd(&launch.GetBody());
builder->create<tf_device::ReturnOp>(loc, op->getResults());
// Move op inside launch.
op->moveBefore(launch.GetBody().getTerminator());
builder->restoreInsertionPoint(insert_point);
}
// Performs the transformation for a replicate op inside a while loop.
void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
MLIRContext* context) {
int64_t num_replicas = replicate.n().getLimitedValue();
if (num_replicas == 1) return;
TF::TPUExecuteAndUpdateVariablesOp execute;
for (auto execute_op :
replicate.GetBody().getOps<TF::TPUExecuteAndUpdateVariablesOp>()) {
if (execute == nullptr) {
execute = execute_op;
tf_device::LaunchOp execute_launch;
for (auto execute_launch_op :
replicate.GetBody().getOps<tf_device::LaunchOp>()) {
if (!execute_launch_op.WrapsSingleOp() ||
!llvm::isa<TF::TPUExecuteAndUpdateVariablesOp>(
execute_launch_op.GetBody().front()))
continue;
if (execute_launch == nullptr) {
execute_launch = execute_launch_op;
} else {
// We only support one execute op inside replicate.
execute = nullptr;
execute_launch = nullptr;
break;
}
}
if (!execute) return;
if (!execute_launch) return;
auto execute = llvm::cast<TF::TPUExecuteAndUpdateVariablesOp>(
execute_launch.GetBody().front());
auto compile =
SkipIdentity(execute.key(), /*allow_other_use=*/true).getDefiningOp();
if (!compile) return;
auto compile_launch = llvm::dyn_cast<tf_device::LaunchOp>(compile);
if (!compile_launch || !compile_launch.WrapsSingleOp() ||
compile_launch.GetBody().front().getName().getStringRef() !=
"tf._TPUCompileMlir")
return;
auto module = while_op.getParentOfType<ModuleOp>();
auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body()));
@ -421,7 +453,7 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
// Analyze the formattable inputs.
auto execute_arg_to_outer_args =
AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
while_op, replicate, execute, compile, body, cond);
while_op, replicate, execute, compile_launch, body, cond);
if (execute_arg_to_outer_args.empty()) return;
// Extract the replicated devices.
@ -468,13 +500,15 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
for (const auto& entry : execute_arg_to_outer_args) {
reformat_operands.push_back(execute.args()[entry.first]);
}
reformat_operands.push_back(compile->getResult(1));
reformat_operands.push_back(compile_launch.getResult(1));
reformat_operands.push_back(replicate.GetBody().getArgument(
replicate.GetBody().getNumArguments() - 1));
builder.setInsertionPoint(execute);
builder.create<TF::TPUReshardVariablesOp>(
execute.getLoc(), llvm::ArrayRef<Type>{}, reformat_operands,
builder.setInsertionPoint(execute_launch);
auto reformat_op = builder.create<TF::TPUReshardVariablesOp>(
execute_launch.getLoc(), llvm::ArrayRef<Type>{}, reformat_operands,
llvm::ArrayRef<NamedAttribute>{});
WrapOpInLaunch(&builder, execute_launch.getLoc(), reformat_op,
execute_launch.device());
// Build the replicated unformat op after the loop. First prepare building the
// replicate op.
@ -514,9 +548,11 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
unformat_operands.begin() + unformat_operands.size() - 1,
default_state_key.getResult());
// Unformat op.
builder.create<TF::TPUReshardVariablesOp>(
auto unformat_op = builder.create<TF::TPUReshardVariablesOp>(
while_op.getLoc(), llvm::ArrayRef<Type>{}, unformat_operands,
llvm::ArrayRef<NamedAttribute>{});
WrapOpInLaunch(&builder, execute_launch.getLoc(), unformat_op,
execute_launch.device());
builder.create<tf_device::ReturnOp>(while_op.getLoc(), ArrayRef<Value>{});
}

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
#include <climits>
#include <cstdint>
@ -27,6 +27,7 @@ limitations under the License.
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
@ -35,16 +36,11 @@ limitations under the License.
#include "mlir/Support/Functional.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"
namespace mlir {
namespace TFL {
namespace TF {
namespace {
// Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out
@ -75,7 +71,7 @@ TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
Type resultType = RankedTensorType::get(shape, element_type);
auto constant_attr = DenseElementsAttr::get(shape_spec_type, shape);
auto shape_tensor =
rewriter.create<ConstantOp>(loc, shape_spec_type, constant_attr);
rewriter.create<TF::ConstOp>(loc, shape_spec_type, constant_attr);
return rewriter.create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
/*shape=*/shape_tensor);
}
@ -108,8 +104,8 @@ std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
auto begin_attr =
DenseElementsAttr::get<int64_t>(vector3_type, {batch_idx, 0, 0});
auto size_attr = DenseElementsAttr::get<int64_t>(vector3_type, slice_size);
auto begin = rewriter.create<ConstantOp>(loc, vector3_type, begin_attr);
auto size = rewriter.create<ConstantOp>(loc, vector3_type, size_attr);
auto begin = rewriter.create<TF::ConstOp>(loc, vector3_type, begin_attr);
auto size = rewriter.create<TF::ConstOp>(loc, vector3_type, size_attr);
auto slice_op = rewriter.create<TF::SliceOp>(loc, slice_result_type,
/*input=*/reshape_op.output(),
begin, size);
@ -312,8 +308,12 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
}
static PassRegistration<UnrollBatchMatMulPass> pass(
"tfl-unroll-batch-matmul",
"tf-unroll-batch-matmul",
"Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops.");
} // namespace TFL
std::unique_ptr<OpPassBase<FuncOp>> CreateUnrollBatchMatMulPassPass() {
return std::make_unique<UnrollBatchMatMulPass>();
}
} // namespace TF
} // namespace mlir

View File

@ -1,4 +1,4 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Location.h" // TF:llvm-project
@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/core/util/matmul_bcast.h"
namespace mlir {
namespace TFL {
namespace TF {
// Unroll tf.BatchMatMulV2 op into a sequence of TF ops. Since TFLite does not
// support BatchMatMul operation, it unrolls a BatchMatMul op into tf.Reshape,
@ -53,7 +53,7 @@ class ConvertTFBatchMatMulOp : public OpRewritePattern<BatchMatMulOpType> {
PatternRewriter& rewriter) const override;
};
} // namespace TFL
} // namespace TF
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_UNROLL_BATCH_MATMUL_H_

View File

@ -40,11 +40,13 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Analysis/Verifier.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Identifier.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
@ -63,6 +65,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
@ -162,7 +165,8 @@ class ImporterBase {
specs_(specs),
debug_info_(debug_info),
function_name_for_debug_info_(function_name_for_debug_info),
function_name_uniquifier_(function_name_uniquifier) {}
function_name_uniquifier_(function_name_uniquifier),
error_handler_(module.getContext()) {}
// Returns the inferred function signature of the given function body. Input
// types are unranked tensor of the respective datatype in the function and
@ -318,8 +322,9 @@ class ImporterBase {
// the location.
mlir::Location GetLocation(const NodeDef& node);
// Gets the location information string for the given node.
std::string GetLocationStr(const Node& node, bool includeNodeName = false);
// Appends the location string for the node to the error message and returns
// the combined error status.
Status EmitErrorWithLocationStr(const Node& node, const Status& error_status);
// Inserts a placeholder node in the graph to replace a feed output tensor,
// and returns the new placeholder node and a boolean indicating if the
@ -368,6 +373,7 @@ class ImporterBase {
NodeValueMap node_values_;
std::unique_ptr<ShapeRefiner> shape_refiner_;
NameUniquifier* function_name_uniquifier_;
mlir::StatusScopedDiagnosticHandler error_handler_;
protected:
// Maps feed as TensorId to new Placeholder node name.
@ -669,9 +675,10 @@ Status ImporterBase::AddNodesToShapeRefiner() {
remapped_feeds_[{it->first, index}] = placeholder_node->name();
node_name_map[placeholder_node->name()] = placeholder_node;
// Add the new placeholder node to the shape refiner.
TF_RETURN_WITH_CONTEXT_IF_ERROR(
shape_refiner_->AddNode(placeholder_node),
GetLocationStr(*placeholder_node));
Status status = shape_refiner_->AddNode(placeholder_node);
if (!status.ok()) {
return EmitErrorWithLocationStr(*placeholder_node, status);
}
}
} else {
auto index_it = it->second.find(0);
@ -691,8 +698,10 @@ Status ImporterBase::AddNodesToShapeRefiner() {
}
if (!node_added_to_shape_refiner) {
// Add the node to the shape refiner if the node hasn't been removed.
TF_RETURN_WITH_CONTEXT_IF_ERROR(shape_refiner_->AddNode(node),
GetLocationStr(*node));
Status status = shape_refiner_->AddNode(node);
if (!status.ok()) {
return EmitErrorWithLocationStr(*node, status);
}
}
auto set_shape_from_list_attr = [&](const AttrValue* attr) {
@ -700,9 +709,11 @@ Status ImporterBase::AddNodesToShapeRefiner() {
for (auto shape : llvm::enumerate(list.shape())) {
auto* node_context = shape_refiner_->GetContext(node);
shape_inference::ShapeHandle handle;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
node_context->MakeShapeFromShapeProto(shape.value(), &handle),
GetLocationStr(*node));
Status status =
node_context->MakeShapeFromShapeProto(shape.value(), &handle);
if (!status.ok()) {
return EmitErrorWithLocationStr(*node, status);
}
node_context->set_output(shape.index(), handle);
}
return Status::OK();
@ -735,9 +746,11 @@ Status ImporterBase::AddNodesToShapeRefiner() {
DCHECK(node_context != nullptr);
if (const AttrValue* attr = node->attrs().Find("shape")) {
shape_inference::ShapeHandle handle;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
node_context->MakeShapeFromShapeProto(attr->shape(), &handle),
GetLocationStr(*node));
Status status =
node_context->MakeShapeFromShapeProto(attr->shape(), &handle);
if (!status.ok()) {
return EmitErrorWithLocationStr(*node, status);
}
node_context->set_output(0, handle);
} else if (const AttrValue* attr = node->attrs().Find("_output_shapes")) {
TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr));
@ -813,9 +826,12 @@ Status ImporterBase::AddNodesToShapeRefiner() {
existing.push_back(shape_context->output(o));
}
bool inferred = false;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
shape_refiner_->UpdateNode(node, /*relax=*/false, &inferred),
GetLocationStr(*node));
shape_inference::ShapeHandle handle;
Status status =
shape_refiner_->UpdateNode(node, /*relax=*/false, &inferred);
if (!status.ok()) {
return EmitErrorWithLocationStr(*node, status);
}
for (int o = 0; o < shape_context->num_outputs(); ++o) {
if (!same_inferred_shape(shape_context, shape_context->output(o),
existing[o])) {
@ -1346,18 +1362,11 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
}
}
std::string ImporterBase::GetLocationStr(const Node& node,
bool includeNodeName) {
const auto location = GetLocation(node.def());
std::string s;
llvm::raw_string_ostream ss(s);
location.print(ss);
ss.flush();
// Removes the node name prefix if it exists.
if (!s.empty() && s[0] == '\"' && s.find_first_of(node.name()) == 1) {
return s.replace(0, node.name().size() + 3, "");
}
return s;
Status ImporterBase::EmitErrorWithLocationStr(const Node& node,
const Status& error_status) {
const mlir::Location location = GetLocation(node.def());
mlir::emitError(location);
return error_handler_.Combine(error_status);
}
mlir::Operation* ImporterBase::createOperation(

View File

@ -210,6 +210,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
bool use_tuple_args, bool return_tuple) {
mlir::PassManager tf2xla(module_op.getContext());
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass());
tf2xla.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass());
tf2xla.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
// LegalizeTFControlFlow encapsulates arguments for control flow operations
@ -221,6 +222,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
// and canonicalization opportunities that are necessary for the second
// LegalizeTFPass(allow_partial_conversion=false) invocation.
tf2xla.addNestedPass<mlir::FuncOp>(mlir::xla_hlo::createLegalizeTFPass(true));
tf2xla.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass());
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
tf2xla.addNestedPass<mlir::FuncOp>(
mlir::xla_hlo::createLegalizeTFPass(false));

View File

@ -49,7 +49,9 @@ Status LoadProtoFromFile(absl::string_view input_filename,
const auto file_or_err =
llvm::MemoryBuffer::getFileOrSTDIN(StringViewToRef(input_filename));
if (std::error_code error = file_or_err.getError()) {
return errors::InvalidArgument("Could not open input file");
return errors::InvalidArgument(
"Could not open input file ",
string(input_filename.data(), input_filename.size()).c_str());
}
const auto& input_file = *file_or_err;

View File

@ -433,6 +433,7 @@ cc_library(
includes = ["include"],
deps = [
":convert_op_folder",
":hlo",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla/service:hlo",
"@llvm-project//mlir:IR",

View File

@ -198,7 +198,8 @@ tensorflow::Status HloFunctionImporter::ImportInstructions(
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
HloInstruction* instruction, mlir::OpBuilder* func_builder) {
TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction));
TF_ASSIGN_OR_RETURN(auto result_type, ConvertType(instruction->shape()));
TF_ASSIGN_OR_RETURN(auto result_type, ConvertShapeToType<RankedTensorType>(
instruction->shape(), *builder_));
llvm::SmallVector<NamedAttribute, 10> attributes = {builder_->getNamedAttr(
"name", builder_->getStringAttr(instruction->name()))};
mlir::Location loc = func_builder->getUnknownLoc();
@ -699,29 +700,12 @@ StatusOr<llvm::SmallVector<mlir::Value, 4>> HloFunctionImporter::GetOperands(
return operands;
}
StatusOr<mlir::Type> HloFunctionImporter::ConvertType(const Shape& shape) {
if (shape.IsToken()) {
return mlir::xla_hlo::TokenType::get(builder_->getContext());
}
if (shape.IsTuple()) {
llvm::SmallVector<mlir::Type, 4> contents;
contents.reserve(shape.tuple_shapes_size());
for (const auto& subtype : shape.tuple_shapes()) {
TF_ASSIGN_OR_RETURN(auto mlir_subtype, ConvertType(subtype));
contents.push_back(mlir_subtype);
}
return builder_->getTupleType(contents);
}
return ConvertTensorShapeToType<RankedTensorType>(shape, *builder_);
}
tensorflow::Status HloFunctionImporter::GetMlirTypes(
const std::vector<HloInstruction*>& instructions,
llvm::SmallVectorImpl<mlir::Type>* types) {
for (auto instruction : instructions) {
TF_ASSIGN_OR_RETURN(auto ret_type, ConvertType(instruction->shape()));
TF_ASSIGN_OR_RETURN(auto ret_type, ConvertShapeToType<RankedTensorType>(
instruction->shape(), *builder_));
types->push_back(ret_type);
}
return tensorflow::Status::OK();

View File

@ -78,9 +78,6 @@ class HloFunctionImporter {
// Converts xla Tensor type to the corresponding MLIR type.
StatusOr<mlir::RankedTensorType> ConvertTensorType(const xla::Shape& shape);
// Converts xla Primitive types to the corresponding MLIR type.
StatusOr<mlir::Type> ConvertType(const xla::Shape& shape);
// Returns the output type of an HloInstruction.
StatusOr<mlir::Type> GetReturnType(xla::HloInstruction* instruction);

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
namespace xla {
@ -69,6 +70,9 @@ static StatusOr<mlir::Type> ConvertShapeToType(const Shape& shape,
}
return builder.getTupleType(contents);
}
if (shape.IsToken()) {
return mlir::xla_hlo::TokenType::get(builder.getContext());
}
return ConvertTensorShapeToType<TypeT>(shape, builder);
}

View File

@ -32,11 +32,13 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Dialect.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" // TF:llvm-project
#include "mlir/IR/OpDefinition.h" // TF:llvm-project
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
@ -46,6 +48,7 @@ limitations under the License.
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
@ -479,6 +482,48 @@ static LogicalResult Verify(BroadcastInDimOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// ScalarsToDimensionTensorOp
//===----------------------------------------------------------------------===//
namespace {
// Canonicalizes the pattern of the form
//
// %2 = "xla_hlo.scalars_to_dimension_tensor"(%0, %1)
// : (i32, i32) -> tensor<2xi32>
// %3 = extract_element %2[%c0] : tensor<2xi32>
//
// to just %0.
struct ExtractElementFromScalarsToDimensionTensor
: public OpRewritePattern<ExtractElementOp> {
using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(ExtractElementOp extract,
PatternRewriter& rewriter) const override {
if (extract.indices().size() != 1) return matchFailure();
if (auto scalars_to_tensor = dyn_cast_or_null<ScalarsToDimensionTensorOp>(
extract.aggregate().getDefiningOp())) {
APInt index;
if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index))) {
return matchFailure();
}
rewriter.replaceOp(extract,
scalars_to_tensor.getOperand(index.getZExtValue()));
return matchSuccess();
}
return matchFailure();
}
};
} // namespace
void ScalarsToDimensionTensorOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<ExtractElementFromScalarsToDimensionTensor>(context);
}
//===----------------------------------------------------------------------===//
// DynamicBroadcastInDimOp
//===----------------------------------------------------------------------===//

View File

@ -784,11 +784,12 @@ def HLO_ScalarsToDimensionTensorOp : HLO_Op<"scalars_to_dimension_tensor",
compute shape arguments to dynamic operations.
}];
let arguments = (ins Variadic<AnySignlessInteger>);
let arguments = (ins Variadic<AnySignlessInteger>:$scalars);
let results = (outs HLO_DimensionTensor);
// Cannot be exported to legacy formats.
let hasCustomHLOConverter = 1;
let hasCanonicalizer = 1;
}
def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",

View File

@ -64,3 +64,13 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
%0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// CHECK-LABEL: @extract_scalars_to_tensor
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
func @extract_scalars_to_tensor(%arg0: i32, %arg1: i32) -> i32 {
%0 = "xla_hlo.scalars_to_dimension_tensor"(%arg0, %arg1) : (i32, i32) -> tensor<2xi32>
%1 = constant 0 : index
%2 = extract_element %0[%1] : tensor<2xi32>
// CHECK: return %[[ARG0]]
return %2 : i32
}

View File

@ -453,6 +453,14 @@ func @rng_uniform_invalid_type(%mu: tensor<complex<f32>>, %sigma: tensor<f32>) -
// -----
// CHECK-LABEL: @scalars_to_dimension_tensor
func @scalars_to_dimension_tensor(%arg0: i32, %arg1: i32) -> tensor<2xi32> {
%0 = "xla_hlo.scalars_to_dimension_tensor"(%arg0, %arg1) : (i32, i32) -> tensor<2xi32>
return %0 : tensor<2xi32>
}
// -----
// CHECK-LABEL: func @select
func @select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>

View File

@ -35,7 +35,7 @@ from tensorflow.python.ops import image_ops
from tensorflow.python.platform import test
def GenerateNumpyRandomRGB(shape):
def _generate_numpy_random_rgb(shape):
# Only generate floating-point values that are fractions like n / 256, since
# they are RGB pixels. Some low-precision floating-point types in this test
# can't handle arbitrary-precision floating-point values well.
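# One way to produce such values (an illustrative sketch only, not necessarily
# the implementation used here):
#   np.random.randint(low=0, high=256, size=shape) / 256.0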
@ -51,7 +51,7 @@ class RGBToHSVTest(xla_test.XLATestCase):
shape = (batch_size, 2, 7, 3)
for nptype in self.float_types:
inp = GenerateNumpyRandomRGB(shape).astype(nptype)
inp = _generate_numpy_random_rgb(shape).astype(nptype)
# Convert to HSV and back, as a batch and individually
with self.session() as sess:
@ -89,7 +89,7 @@ class RGBToHSVTest(xla_test.XLATestCase):
def testRGBToHSVNumpy(self):
"""Tests the RGB to HSV conversion matches a reference implementation."""
for nptype in self.float_types:
rgb_flat = GenerateNumpyRandomRGB((64, 3)).astype(nptype)
rgb_flat = _generate_numpy_random_rgb((64, 3)).astype(nptype)
rgb_np = rgb_flat.reshape(4, 4, 4, 3)
hsv_np = np.array([
colorsys.rgb_to_hsv(

View File

@ -153,6 +153,10 @@ class TRTEngineOp : public AsyncOpKernel {
// Whether to use implicit batch dimension for TensorRT.
bool use_implicit_batch_;
// Whether to collect optimization profiles for TensorRT, only used when
// use_implicit_batch_=false.
bool profile_generation_mode_;
// Whether to build TensorRT engines at runtime.
bool allow_build_at_runtime_;
@ -322,7 +326,19 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
use_implicit_batch_ = true;
}
#endif
status =
context->GetAttr("_profile_generation_mode", &profile_generation_mode_);
if (status.code() == tensorflow::error::NOT_FOUND) {
VLOG(2) << "Not found _profile_generation_mode in "
<< context->device()->name()
<< ", thus setting _profile_generation_mode=false";
profile_generation_mode_ = false;
}
if (use_implicit_batch_) {
OP_REQUIRES(context, !profile_generation_mode_,
errors::InvalidArgument(
"profile_generation_mode_=true is only supported if "
"use_implicit_batch=false"));
if (input_partial_shapes_.empty()) {
VLOG(1) << "Attribute input_shapes is not set. This happens probably "
<< "because you are using a model that is already converted "
@ -547,11 +563,21 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
OP_REQUIRES_OK_ASYNC(ctx, VerifyInputShapes(input_concrete_shapes), *helper);
if (!use_implicit_batch_) {
if (cache_res->profiles_.GetNumProfiles() == 0) {
// Create a single profile from the current input shape. In the future we
// will collect a set of input shapes during build mode and create
// profiles for each of them.
if (profile_generation_mode_) {
// Collecting new shapes for profiles can only be done once. After the
// shapes are converted to TRT profiles, no more shapes can be collected.
OP_REQUIRES(ctx, cache_res->profiles_.GetNumProfiles() == 0,
errors::Unimplemented("Cannot collect new shapes when "
"profiles are already created."));
// Just collect the input shape info and return. The shapes are used to
// generate optimization profiles during engine creation.
cache_res->profiles_.AddShape(input_concrete_shapes);
VLOG(1) << "Native segment is used during collecting shapes for profiles";
ExecuteNativeSegment(ctx, helper);
return;
} else if (cache_res->profiles_.GetNumProfiles() == 0) {
// Create profiles out of collected shapes during profile generation.
cache_res->profiles_.InitProfiles();
}
}

View File

@ -238,12 +238,16 @@ TEST_F(TRTEngineOpTestBase, ExplicitBatch) {
device_->resource_manager()->Lookup("TF-TRT", "myop", &cache_resource));
core::ScopedUnref sc(cache_resource);
// The cache should contain only one EngineContext, with a valid cuda_engine.
// Due to the way the engine lookup is implemented, explicit batch mode
// requires profile generation. Currently profile generation is not enabled in
// this test, therefore engine creation fails.
//
// TODO(Tamas) find a way to enable profile generation mode and test it
auto cache = &cache_resource->cache_;
EXPECT_EQ(1, cache->size());
ASSERT_EQ(1, cache->count({input_shape}));
EngineContext* ectx = cache->at({input_shape}).get();
EXPECT_NE(ectx->cuda_engine, nullptr);
EXPECT_EQ(0, cache->size());
// ASSERT_EQ(1, cache->count({input_shape}));
// EngineContext* ectx = cache->at({input_shape}).get();
// EXPECT_NE(ectx->cuda_engine, nullptr);
}
TEST_F(TRTEngineOpTestBase, DynamicShapes) {
@ -267,12 +271,13 @@ TEST_F(TRTEngineOpTestBase, DynamicShapes) {
device_->resource_manager()->Lookup("TF-TRT", "myop", &cache_resource));
core::ScopedUnref sc(cache_resource);
// The cache should contain only one EngineContext.
// We did not enable profile generation mode, therefore engine creation failed.
// TODO(Tamas) find a way to enable profile generation mode and test it
auto cache = &cache_resource->cache_;
EXPECT_EQ(1, cache->size());
ASSERT_EQ(1, cache->count({input_shape}));
EngineContext* ectx = cache->at({input_shape}).get();
EXPECT_NE(ectx->cuda_engine, nullptr);
EXPECT_EQ(0, cache->size());
// ASSERT_EQ(1, cache->count({input_shape}));
// EngineContext* ectx = cache->at({input_shape}).get();
// EXPECT_NE(ectx->cuda_engine, nullptr);
}
template <typename T>

View File

@ -70,7 +70,15 @@ void* TRTDeviceAllocator::allocate(uint64_t size, uint64_t alignment,
// TODO(aaroey): AllocateRaw takes a size_t size as input, so it'll produce an
// unexpected result when TRT tries to allocate more bytes than size_t can
// carry. Fix this.
void* mem = allocator_->AllocateRaw(alignment, total_size);
//
// Fail immediately if allocation fails, rather than waiting 10 seconds and
// failing then anyway.
// TensorRT 7 can also switch to a different algorithm for a layer if the
// current one uses too much memory. If we don't fail immediately, building the
// engine can be *very* slow with TensorRT 7 when GPU memory is limited.
AllocationAttributes attributes;
attributes.no_retry_on_failure = true;
void* mem = allocator_->AllocateRaw(alignment, total_size, attributes);
if (!mem) return nullptr;
void* alloc_mem = mem;

View File

@ -1,4 +1,4 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test", "tf_openmp_copts")
load(
"//tensorflow/core/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
@ -11,6 +11,7 @@ load(
load("//tensorflow/compiler/xla:xla.bzl", "xla_py_proto_library")
load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
load("//tensorflow/compiler/xla/service/cpu:build_defs.bzl", "runtime_copts")
package(
default_visibility = [":internal"],
@ -180,6 +181,76 @@ cc_library(
],
)
# The filegroups below are explicitly used by
# tensorflow/tools/pip_package:build_pip_package to ensure we include the proper
# sources for the XLA AOT CPU runtime, as these are necessary outside of bazel
# when linking tfcompile objects using saved_model_cli (e.g. using the
# tensorflow pip package). The associated .cc files are included in the
# tensorflow pip package's xla_aot_runtime_srcs/ subdirectory. All necessary
# headers are also included in the pip package's include/tensorflow/ and
# include/external/ subdirectories. Note, however, that sometimes additional
# object files may need to be linked when linking AOT XLA objects, e.g. abseil
# libraries. See the deps attribute of the "xla_compiled_cpu_runtime_standalone"
# target below for an exhaustive list.
filegroup(
name = "xla_compiled_cpu_runtime_hdrs",
srcs = [
"xla_compiled_cpu_function.h",
"//tensorflow/compiler/xla:cpu_runtime_hdrs",
"//tensorflow/compiler/xla/service/cpu:single_threaded_runtime_hdrs",
"//tensorflow/core/kernels:xla_cpu_runtime_hdrs",
"//tensorflow/core/platform:xla_cpu_runtime_srcs",
],
visibility = ["//tensorflow/tools/pip_package:__pkg__"],
)
filegroup(
name = "xla_compiled_cpu_runtime_srcs",
srcs = [
"xla_compiled_cpu_function.cc",
"//tensorflow/compiler/xla:cpu_runtime_srcs",
"//tensorflow/compiler/xla/service/cpu:single_threaded_runtime_srcs",
"//tensorflow/core/kernels:xla_cpu_runtime_srcs",
"//tensorflow/core/platform:xla_cpu_runtime_srcs",
],
visibility = ["//tensorflow/tools/pip_package:__pkg__"],
)
# This stand-alone target is used to ensure that we can build tf_library type
# targets against the subset of sources declared in
# xla_compiled_cpu_runtime_{srcs,hdrs}.
#
# The macros in tensorflow/python/tools/tools.bzl produce AOT compiled binaries
# that rely on this target, as do unit tests in tensorflow/python/tools.
#
# See above for the significance of the source filegroups.
cc_library(
name = "xla_compiled_cpu_runtime_standalone",
srcs = [
":xla_compiled_cpu_runtime_srcs",
],
hdrs = [
":xla_compiled_cpu_runtime_hdrs",
],
copts = runtime_copts() + tf_openmp_copts(),
features = ["fully_static_link"],
linkstatic = 1,
visibility = [":friends"],
# Note: we specifically remove MKL dependencies so the standalone does
# not require the MKL binary blob.
deps = [
"//tensorflow/core/framework:numeric_types",
"//third_party/eigen3",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:dynamic_annotations",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/synchronization",
],
alwayslink = 1,
)
cc_library(
name = "xla_compiled_cpu_function",
srcs = ["xla_compiled_cpu_function.cc"],
@ -190,7 +261,7 @@ cc_library(
# binary produced by tfcompile.
"//tensorflow/compiler/xla:cpu_function_runtime",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/core:framework_lite",
"//tensorflow/core/platform:types",
],
)

View File

@ -36,6 +36,25 @@ filegroup(
]),
)
filegroup(
name = "cpu_runtime_srcs",
srcs = [
"cpu_function_runtime.cc",
"executable_run_options.cc",
],
visibility = [":friends"],
)
filegroup(
name = "cpu_runtime_hdrs",
srcs = [
"cpu_function_runtime.h",
"executable_run_options.h",
"types.h",
],
visibility = [":friends"],
)
tf_proto_library_cc(
name = "xla_data_proto",
srcs = ["xla_data.proto"],
@ -142,7 +161,8 @@ cc_library(
hdrs = ["types.h"],
visibility = [":friends"],
deps = [
"//tensorflow/core:framework_lite",
"//tensorflow/core/framework:numeric_types",
"//tensorflow/core/platform:types",
"//third_party/eigen3",
],
)
@ -620,7 +640,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":types",
"@com_google_absl//absl/strings",
],
)
@ -896,7 +915,10 @@ cc_library(
srcs = ["cpu_function_runtime.cc"],
hdrs = ["cpu_function_runtime.h"],
visibility = [":friends"],
deps = ["//tensorflow/core:framework_lite"],
deps = [
"//tensorflow/core/platform:dynamic_annotations",
"//tensorflow/core/platform:types",
],
)
tf_cc_test(

View File

@ -271,8 +271,7 @@ static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer(
StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
absl::Span<Shape const* const> argument_host_shapes,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
ExecutableRunOptions run_options) {
std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
if (argument_host_shapes.size() != arguments.size()) {
return InvalidArgument(
"Number of argument host shapes not equal to number of arguments (%d "
@ -291,8 +290,8 @@ StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
shaped_buffer_ptrs.reserve(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
shaped_buffers.push_back(MaybeOwningShapeTreeToShapedBuffer(
*argument_host_shapes[i], arguments[i], backend_->platform(),
stream->parent()->device_ordinal()));
*argument_host_shapes[i], arguments[i].Buffers(),
backend_->platform(), stream->parent()->device_ordinal()));
shaped_buffer_ptrs.push_back(&shaped_buffers.back());
}

View File

@ -61,8 +61,7 @@ class LocalExecutable {
// executable.
StatusOr<ExecutionOutput> RunAsync(
absl::Span<Shape const* const> argument_host_shapes,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
ExecutableRunOptions run_options);
std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options);
// Return the options used to build the executable.
const ExecutableBuildOptions& build_options() const { return build_options_; }
@ -76,8 +75,8 @@ class LocalExecutable {
//
// The given ExecutableRunOptions override any values from TF_XLA_FLAGS
// environment variable.
Status ValidateExecutionOptions(
const ExecutableRunOptions& run_options, const Backend& backend);
Status ValidateExecutionOptions(const ExecutableRunOptions& run_options,
const Backend& backend);
// Returns a literal containing the contents of the given ShapedBuffer.
StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);

View File

@ -70,45 +70,6 @@ void SetProtoIdAndName(T* entry, const string& base_name, char separator,
entry->set_id(id);
entry->set_name(GetFullName(base_name, separator, id));
}
template <typename InstructionType>
StatusOr<InstructionType> LookUpInstructionByHandleInternal(
const absl::flat_hash_map<int64, int64>& handle_to_index,
const std::vector<HloInstructionProto>& instructions, int64 handle) {
auto it = handle_to_index.find(handle);
if (it == handle_to_index.end()) {
return InvalidArgument("No XlaOp with handle %d", handle);
}
return const_cast<InstructionType>(&instructions.at(it->second));
}
Status CheckBuildersAffinity(const XlaBuilder* op_builder,
const XlaBuilder* builder, int64 handle) {
if (op_builder != builder) {
return InvalidArgument(
"XlaOp with handle %d is built by builder '%s', but is trying to use "
"it in builder '%s'",
handle, op_builder->name(), builder->name());
}
return Status::OK();
}
template <typename InstructionType, typename OpBuilderType,
typename BuilderType, typename OpType>
StatusOr<InstructionType> LookUpInstructionInternal(
const absl::flat_hash_map<int64, int64>& handle_to_index,
const std::vector<HloInstructionProto>& instructions,
OpBuilderType op_builder, BuilderType builder, OpType op_handle) {
if (op_builder == nullptr) {
return InvalidArgument(
"Invalid XlaOp with handle %d; the builder of this op is freed",
op_handle);
}
TF_RETURN_IF_ERROR(CheckBuildersAffinity(op_builder, builder, op_handle));
return LookUpInstructionByHandleInternal<InstructionType>(
handle_to_index, instructions, op_handle);
}
} // namespace
XlaOp operator-(XlaOp x) { return Neg(x); }
@ -143,7 +104,7 @@ XlaOp operator>>(XlaOp x, XlaOp y) {
StatusOr<const Shape*> XlaBuilder::GetShapePtr(XlaOp op) const {
TF_RETURN_IF_ERROR(first_error_);
TF_RETURN_IF_ERROR(CheckBuildersAffinity(op.builder(), this, op.handle()));
TF_RETURN_IF_ERROR(CheckOpBuilder(op));
auto it = handle_to_index_.find(op.handle());
if (it == handle_to_index_.end()) {
return InvalidArgument("No XlaOp with handle %d", op.handle());
@ -1941,6 +1902,16 @@ XlaOp XlaBuilder::ConditionalImpl(
});
}
Status XlaBuilder::CheckOpBuilder(XlaOp op) const {
if (this != op.builder()) {
return InvalidArgument(
"XlaOp with handle %d is built by builder '%s', but is trying to use "
"it in builder '%s'",
op.handle(), op.builder()->name(), name());
}
return Status::OK();
}
XlaOp XlaBuilder::Reduce(XlaOp operand, XlaOp init_value,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce) {
@ -2886,27 +2857,23 @@ void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
const XlaOp op) const {
TF_RETURN_IF_ERROR(first_error_);
return LookUpInstructionInternal<const HloInstructionProto*>(
handle_to_index_, instructions_, op.builder_, this, op.handle());
return LookUpInstructionInternal<const HloInstructionProto*>(op);
}
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
int64 handle) const {
return LookUpInstructionByHandleInternal<const HloInstructionProto*>(
handle_to_index_, instructions_, handle);
return LookUpInstructionByHandleInternal<const HloInstructionProto*>(handle);
}
StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstruction(
const XlaOp op) {
TF_RETURN_IF_ERROR(first_error_);
return LookUpInstructionInternal<HloInstructionProto*>(
handle_to_index_, instructions_, op.builder_, this, op.handle());
return LookUpInstructionInternal<HloInstructionProto*>(op);
}
StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstructionByHandle(
int64 handle) {
return LookUpInstructionByHandleInternal<HloInstructionProto*>(
handle_to_index_, instructions_, handle);
return LookUpInstructionByHandleInternal<HloInstructionProto*>(handle);
}
// Enqueues a "retrieve parameter value" instruction for a parameter that was

View File

@ -1061,6 +1061,34 @@ class XlaBuilder {
XlaOp branch_index,
absl::Span<const XlaComputation* const> branch_computations,
absl::Span<const XlaOp> branch_operands);
// Returns OK status if the given op was built using this builder. Otherwise,
// returns an error.
Status CheckOpBuilder(XlaOp op) const;
// Here, InstructionType is either const HloInstructionProto* or non-const
// HloInstructionProto*.
template <typename InstructionType>
StatusOr<InstructionType> LookUpInstructionByHandleInternal(
int64 handle) const {
auto it = handle_to_index_.find(handle);
if (it == handle_to_index_.end()) {
return InvalidArgument("No XlaOp with handle %d", handle);
}
return const_cast<InstructionType>(&instructions_.at(it->second));
}
// Here, InstructionType is either const HloInstructionProto* or non-const
// HloInstructionProto*.
//
// TODO(hinsu): Return const pointer within StatusOr and use
// absl::implicit_cast at callsites. This requires implicit_cast support in
// stream_executor::port::StatusOr similar to absl::StatusOr.
template <typename InstructionType>
StatusOr<InstructionType> LookUpInstructionInternal(XlaOp op) const {
TF_RETURN_IF_ERROR(CheckOpBuilder(op));
return LookUpInstructionByHandleInternal<InstructionType>(op.handle());
}
};
// RAII-style object: sets the current sharding assignment in builder on

View File

@ -17,8 +17,6 @@ limitations under the License.
#include <atomic>
#include "absl/strings/str_cat.h"
namespace xla {
RunId::RunId() {
@ -28,7 +26,9 @@ RunId::RunId() {
bool operator==(const RunId& a, const RunId& b) { return a.data_ == b.data_; }
std::string RunId::ToString() const { return absl::StrCat("RunId: ", data_); }
std::string RunId::ToString() const {
return "RunId: " + std::to_string(data_);
}
ExecutableRunOptions& ExecutableRunOptions::set_device_ordinal(
int device_ordinal) {

tensorflow/compiler/xla/service/BUILD Executable file → Normal file
View File

@ -3109,6 +3109,7 @@ cc_library(
deps = [
":heap_simulator",
":hlo_cost_analysis",
"//tensorflow/compiler/xla:debug_options_flags",
],
)

View File

@ -30,6 +30,32 @@ filegroup(
]),
)
filegroup(
name = "single_threaded_runtime_srcs",
srcs = [
"runtime_fp16.cc",
"runtime_key_value_sort.cc",
"runtime_single_threaded_conv2d.cc",
"runtime_single_threaded_fft.cc",
"runtime_single_threaded_matmul.cc",
],
visibility = [":friends"],
)
filegroup(
name = "single_threaded_runtime_hdrs",
srcs = [
"runtime_conv2d_impl.h",
"runtime_fft_impl.h",
"runtime_fp16.h",
"runtime_key_value_sort.h",
"runtime_single_threaded_conv2d.h",
"runtime_single_threaded_fft.h",
"runtime_single_threaded_matmul.h",
],
visibility = [":friends"],
)
cc_library(
name = "cpu_transfer_manager",
srcs = ["cpu_transfer_manager.cc"],
@ -219,7 +245,8 @@ cc_library(
],
copts = runtime_copts(),
deps = [
"//tensorflow/core:framework_lite",
"//tensorflow/core/platform:macros",
"//tensorflow/core/platform:types",
],
)
@ -545,8 +572,10 @@ cc_library(
deps = [
":runtime_lightweight_check",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/core:framework_lite",
"//tensorflow/core/kernels:eigen_helpers",
"//tensorflow/core/kernels:eigen_helpers_no_mkl",
"//tensorflow/core/platform:dynamic_annotations",
"//tensorflow/core/platform:mutex",
"//tensorflow/core/platform:types",
"//third_party/eigen3",
],
)
@ -563,7 +592,8 @@ cc_library(
":runtime_conv2d",
":runtime_single_threaded_conv2d",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/core:framework_lite",
"//tensorflow/core/platform:dynamic_annotations",
"//tensorflow/core/platform:types",
"//tensorflow/core/kernels:eigen_helpers",
"//third_party/eigen3",
] + mkl_deps(),
@ -581,8 +611,10 @@ cc_library(
deps = [
":runtime_lightweight_check",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core:framework_lite",
"//tensorflow/core/framework:numeric_types",
"//tensorflow/core/platform:dynamic_annotations",
"//tensorflow/core/platform:mutex",
"//tensorflow/core/platform:types",
"//third_party/eigen3",
],
)
@ -596,8 +628,10 @@ cc_library(
deps = [
":runtime_lightweight_check",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/core:framework_lite",
"//tensorflow/core/kernels:eigen_contraction_kernel",
"//tensorflow/core/platform:dynamic_annotations",
"//tensorflow/core/platform:mutex",
"//tensorflow/core/platform:types",
"//third_party/eigen3",
],
)
@ -610,7 +644,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/core:framework_lite",
"//tensorflow/core/platform:types",
"//third_party/eigen3",
] + mkl_deps(),
)
@ -626,8 +660,9 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":runtime_lightweight_check",
"//tensorflow/core:framework_lite",
"//tensorflow/core/kernels:eigen_helpers",
"//tensorflow/core/platform:dynamic_annotations",
"//tensorflow/core/platform:types",
"//third_party/eigen3",
],
)
@ -643,7 +678,9 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core:framework_lite",
"//tensorflow/core/framework:numeric_types",
"//tensorflow/core/platform:dynamic_annotations",
"//tensorflow/core/platform:types",
"//third_party/eigen3",
],
)
@ -655,8 +692,9 @@ cc_library(
copts = runtime_copts(),
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework_lite",
"//tensorflow/core/kernels:eigen_contraction_kernel",
"//tensorflow/core/platform:dynamic_annotations",
"//tensorflow/core/platform:types",
"//third_party/eigen3",
],
)
@ -668,7 +706,9 @@ cc_library(
copts = runtime_copts(),
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework_lite",
"//tensorflow/core/platform:dynamic_annotations",
"//tensorflow/core/platform:macros",
"//tensorflow/core/platform:types",
"//third_party/eigen3",
],
)
@ -681,8 +721,10 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/platform:blocking_counter",
"//tensorflow/core/platform:dynamic_annotations",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:types",
"//third_party/eigen3",
],
)
@ -711,6 +753,23 @@ tf_cc_test(
],
)
tf_cc_test(
name = "runtime_fft_test",
srcs = [
"runtime_fft_impl.h",
"runtime_fft_test.cc",
],
deps = [
":runtime_single_threaded_fft",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core/framework:numeric_types",
"//third_party/eigen3",
],
)
tf_cc_test(
name = "cpu_instruction_fusion_test",
srcs = ["cpu_instruction_fusion_test.cc"],

View File

@ -78,9 +78,9 @@ CpuExecutable::CpuExecutable(
StatusOr<std::tuple<std::vector<se::DeviceMemoryBase>,
std::vector<se::OwningDeviceMemory>,
std::vector<se::OwningDeviceMemory>>>
CpuExecutable::CreateBufferTable(
se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments) {
CpuExecutable::CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
int device_ordinal,
std::vector<ExecutionInput> arguments) {
std::vector<se::DeviceMemoryBase> unowning_buffers(
assignment_->Allocations().size());
std::vector<se::OwningDeviceMemory> owning_buffers(
@ -95,7 +95,7 @@ CpuExecutable::CreateBufferTable(
if (allocation.is_entry_computation_parameter()) {
unowning_buffers[i] = arguments[allocation.parameter_number()]
.element(allocation.param_shape_index())
.Buffer(allocation.param_shape_index())
.AsDeviceMemoryBase();
CHECK_EQ(allocation.size(), unowning_buffers[i].size())
<< "Size mismatch on param " << allocation.parameter_number()
@ -139,9 +139,9 @@ CpuExecutable::CreateBufferTable(
VLOG(3) << "result index: " << result_slice.index();
std::vector<se::OwningDeviceMemory> buffers_to_free;
for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
auto maybe_owning_buffer = buffer.second.Release();
for (auto& argument : arguments) {
for (auto& index_buffer : *argument.MutableBuffers()) {
auto maybe_owning_buffer = index_buffer.second.Release();
if (maybe_owning_buffer) {
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
}
@ -284,7 +284,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) {
if (GetRootValueSet().IsAmbiguous()) {
return Unimplemented("Points-to set of root instruction is ambiguous");
@ -297,7 +297,7 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
for (int64 i = 0; i < entry_comp->num_parameters(); ++i) {
const Shape& expected_shape =
entry_comp->parameter_instruction(i)->shape();
const Shape& actual_shape = arguments[i].shape();
const Shape& actual_shape = arguments[i].Buffers().shape();
CHECK(
Shape::Equal().IgnoreDynamicDimension()(expected_shape, actual_shape))
<< absl::StreamFormat(
@ -355,9 +355,7 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
std::make_shared<std::vector<se::OwningDeviceMemory>>(
std::move(owning_buffers)),
hlo_execution_profile});
return ExecutionOutput(std::move(result), std::move(buffers_to_release), {},
se::OwningDeviceMemory());
return ExecutionOutput(std::move(result), std::move(buffers_to_release));
}
/*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) {

View File

@ -57,7 +57,7 @@ class CpuExecutable : public Executable {
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) override;
// This should be called after set_ir_module_string.
@ -103,8 +103,7 @@ class CpuExecutable : public Executable {
std::vector<se::OwningDeviceMemory>,
std::vector<se::OwningDeviceMemory>>>
CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
int device_ordinal,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments);
int device_ordinal, std::vector<ExecutionInput> arguments);
// Calls the generated function performing the computation with the given
// arguments using the supplied buffers.

View File

@ -33,7 +33,8 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenFft(
const xla::ExecutableRunOptions* run_options =
static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr);
tensorflow::xla::EigenFftImpl(*run_options->intra_op_thread_pool(), out,
operand, fft_type, fft_rank, input_batch,
fft_length0, fft_length1, fft_length2);
tensorflow::xla::EigenFftImpl(
*run_options->intra_op_thread_pool(), out, operand,
static_cast<tensorflow::xla::FftType>(fft_type), fft_rank, input_batch,
fft_length0, fft_length1, fft_length2);
}

View File

@ -19,7 +19,6 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/platform/types.h"
@ -28,6 +27,15 @@ limitations under the License.
namespace tensorflow {
namespace xla {
enum class FftType : int32 {
FFT = 0, // Forward FFT; complex in, complex out.
IFFT = 1, // Inverse FFT; complex in, complex out.
RFFT = 2, // Forward real FFT; real in, fft_length / 2 + 1 complex out
IRFFT = 3, // Inverse real FFT; fft_length / 2 + 1 complex in,
// fft_length real out
};
static constexpr int kFftTypeArraySize = 4;
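// These values are intended to mirror the xla::FftType proto enum so that the
// int32 fft_type passed in by generated code can be static_cast directly;
// runtime_fft_test.cc verifies that the two enums stay in sync.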
namespace internal {
// Computes either a forward or reverse complex-to-complex FFT.
@ -170,27 +178,27 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand,
template <int FFTRank, typename EigenDevice>
void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
int32 fft_type, int64 input_batch, int64 fft_length0,
FftType fft_type, int64 input_batch, int64 fft_length0,
int64 fft_length1, int64 fft_length2) {
switch (fft_type) {
case ::xla::FftType::FFT:
case FftType::FFT:
EigenFftC2C<true, FFTRank, EigenDevice>(
device, static_cast<complex64*>(out),
static_cast<complex64*>(operand), input_batch, fft_length0,
fft_length1, fft_length2);
break;
case ::xla::FftType::IFFT:
case FftType::IFFT:
EigenFftC2C<false, FFTRank, EigenDevice>(
device, static_cast<complex64*>(out),
static_cast<complex64*>(operand), input_batch, fft_length0,
fft_length1, fft_length2);
break;
case ::xla::FftType::RFFT:
case FftType::RFFT:
EigenFftR2C<FFTRank, EigenDevice>(
device, static_cast<complex64*>(out), static_cast<float*>(operand),
input_batch, fft_length0, fft_length1, fft_length2);
break;
case ::xla::FftType::IRFFT:
case FftType::IRFFT:
EigenFftC2R<FFTRank, EigenDevice>(
device, static_cast<float*>(out), static_cast<complex64*>(operand),
input_batch, fft_length0, fft_length1, fft_length2);
@ -205,7 +213,7 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
template <typename EigenDevice>
void EigenFftImpl(const EigenDevice& device, void* out, void* operand,
int32 fft_type, int32 fft_rank, int64 input_batch,
FftType fft_type, int32 fft_rank, int64 input_batch,
int64 fft_length0, int64 fft_length1, int64 fft_length2) {
switch (fft_rank) {
case 1:

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
TEST(FftTypeTest, MatchesProto) {
EXPECT_EQ(::xla::FftType_ARRAYSIZE, 4);
EXPECT_EQ(::tensorflow::xla::kFftTypeArraySize, 4);
EXPECT_EQ(::xla::FftType::FFT,
static_cast<::tensorflow::int32>(::tensorflow::xla::FftType::FFT));
EXPECT_EQ(::xla::FftType::IFFT,
static_cast<::tensorflow::int32>(::tensorflow::xla::FftType::IFFT));
EXPECT_EQ(::xla::FftType::RFFT,
static_cast<::tensorflow::int32>(::tensorflow::xla::FftType::RFFT));
EXPECT_EQ(::xla::FftType::IRFFT, static_cast<::tensorflow::int32>(
::tensorflow::xla::FftType::IRFFT));
}
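
The test above verifies at run time that the runtime-side enum stays in sync with the xla::FftType proto enum. Where both enums are visible to the compiler, the same invariant can also be stated as a compile-time check; a sketch with hypothetical stand-in enums (ProtoFftType, RuntimeFftType), not the real types:

#include <cstdint>

// Stand-ins: the real pair is ::xla::FftType (proto enum) and
// ::tensorflow::xla::FftType (runtime enum).
enum ProtoFftType : int32_t { FFT = 0, IFFT = 1, RFFT = 2, IRFFT = 3 };
enum class RuntimeFftType : int32_t { FFT = 0, IFFT = 1, RFFT = 2, IRFFT = 3 };

static_assert(static_cast<int32_t>(RuntimeFftType::FFT) == FFT, "FFT mismatch");
static_assert(static_cast<int32_t>(RuntimeFftType::IFFT) == IFFT,
              "IFFT mismatch");
static_assert(static_cast<int32_t>(RuntimeFftType::RFFT) == RFFT,
              "RFFT mismatch");
static_assert(static_cast<int32_t>(RuntimeFftType::IRFFT) == IRFFT,
              "IRFFT mismatch");

int main() { return 0; }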

View File

@ -19,7 +19,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

View File

@ -26,7 +26,8 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedFft(
const void* run_options_ptr, void* out, void* operand, int32 fft_type,
int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1,
int64 fft_length2) {
tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand, fft_type,
tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand,
static_cast<tensorflow::xla::FftType>(fft_type),
fft_rank, input_batch, fft_length0, fft_length1,
fft_length2);
}

View File

@ -44,15 +44,13 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(
return result;
}
static ShapeTree<MaybeOwningDeviceMemory> MakeMaybeOwningDeviceMemoryTree(
static ExecutionInput MakeMaybeOwningDeviceMemoryTree(
const ShapedBuffer& shaped_buffer) {
ShapeTree<MaybeOwningDeviceMemory> result(shaped_buffer.on_device_shape());
auto in_it = shaped_buffer.buffers().begin();
auto out_it = result.begin();
for (; in_it != shaped_buffer.buffers().end(); ++in_it, ++out_it) {
DCHECK(out_it != result.end());
out_it->second = MaybeOwningDeviceMemory(in_it->second);
}
ExecutionInput result(shaped_buffer.on_device_shape());
shaped_buffer.buffers().ForEachElement(
[&](const ShapeIndex& index, const se::DeviceMemoryBase& mem) {
result.SetBuffer(index, MaybeOwningDeviceMemory(mem));
});
return result;
}
@ -60,7 +58,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
absl::Span<const ShapedBuffer* const> arguments,
HloExecutionProfile* hlo_execution_profile) {
std::vector<ShapeTree<MaybeOwningDeviceMemory>> args(arguments.size());
std::vector<ExecutionInput> args(arguments.size());
auto out_it = args.begin();
for (const ShapedBuffer* arg : arguments) {
*out_it++ = MakeMaybeOwningDeviceMemoryTree(*arg);
@ -73,7 +71,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStream(
StatusOr<ExecutionOutput> Executable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) {
StatusOr<ExecutionOutput> result = ExecuteAsyncOnStream(
run_options, std::move(arguments), hlo_execution_profile);
@ -238,7 +236,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStreamWrapper(
StatusOr<ExecutionOutput> Executable::ExecuteAsyncOnStreamWrapper(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments) {
std::vector<ExecutionInput> arguments) {
auto state = ExecuteWrapperBeforeExecution(*this, run_options);
StatusOr<ExecutionOutput> return_value = ExecuteAsyncOnStream(
run_options, std::move(arguments), state.profile_ptr.get());

View File

@ -42,18 +42,75 @@ limitations under the License.
namespace xla {
// TODO(b/150633678): Both the ExecutionInput and ExecutionOutput need to be
// revisited, with the execute APIs taking data structure which can better model
// shareable buffers.
class ExecutionInput {
public:
ExecutionInput() = default;
explicit ExecutionInput(xla::Shape shape) : buffers_(std::move(shape)) {}
explicit ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers)
: buffers_(std::move(buffers)) {}
ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers,
std::vector<ShapeIndex> owner_held_indices)
: buffers_(std::move(buffers)),
unowned_indices_(std::move(owner_held_indices)) {}
ExecutionInput(ExecutionInput&&) = default;
~ExecutionInput() {
for (auto& index : unowned_indices_) {
auto buffer = buffers_.mutable_element(index)->Release();
if (buffer) {
buffer->Release();
}
}
}
ExecutionInput& operator=(ExecutionInput&&) = default;
const Shape& shape() const { return buffers_.shape(); }
void SetBuffer(const ShapeIndex& index, MaybeOwningDeviceMemory buffer) {
*buffers_.mutable_element(index) = std::move(buffer);
}
void SetUnownedBuffer(const ShapeIndex& index,
MaybeOwningDeviceMemory buffer) {
*buffers_.mutable_element(index) = std::move(buffer);
unowned_indices_.push_back(index);
}
const ShapeTree<MaybeOwningDeviceMemory>& Buffers() const { return buffers_; }
ShapeTree<MaybeOwningDeviceMemory>* MutableBuffers() { return &buffers_; }
MaybeOwningDeviceMemory* MutableBuffer(const ShapeIndex& index) {
return buffers_.mutable_element(index);
}
const MaybeOwningDeviceMemory& Buffer(const ShapeIndex& index) const {
return buffers_.element(index);
}
private:
ShapeTree<MaybeOwningDeviceMemory> buffers_;
std::vector<ShapeIndex> unowned_indices_;
};
// ExecutionOutput encapsulates the output buffers of an execution and the
// leftover buffers to be released by the caller.
class ExecutionOutput {
public:
explicit ExecutionOutput(ScopedShapedBuffer result)
: result_(std::move(result)) {}
ExecutionOutput(ScopedShapedBuffer result,
std::vector<se::OwningDeviceMemory> to_be_released,
std::vector<ShapeIndex> aliased_indices,
se::OwningDeviceMemory output_shape_table)
std::vector<se::OwningDeviceMemory> to_be_released)
: result_(std::move(result)),
to_be_released_(std::move(to_be_released)),
aliased_indices_(std::move(aliased_indices)),
output_shape_table_(std::move(output_shape_table)) {}
to_be_released_(std::move(to_be_released)) {}
ExecutionOutput(Shape on_host_shape, Shape on_device_shape,
se::DeviceMemoryAllocator* allocator, int device_ordinal)
: result_(std::move(on_host_shape), std::move(on_device_shape), allocator,
device_ordinal) {}
ExecutionOutput(ExecutionOutput&&) = default;
ExecutionOutput& operator=(ExecutionOutput&&) = default;
@ -66,6 +123,18 @@ class ExecutionOutput {
}
}
void AddAliasedIndex(ShapeIndex index) {
aliased_indices_.push_back(std::move(index));
}
void AddToBeReleased(se::OwningDeviceMemory mem) {
to_be_released_.push_back(std::move(mem));
}
void SetOutputShapeTable(se::OwningDeviceMemory output_shape_table) {
output_shape_table_ = std::move(output_shape_table);
}
// Should be called once it is known that the execute operation succeeded,
// before returning the ExecutionOutput to the caller.
ExecutionOutput& Commit() {
@ -75,6 +144,8 @@ class ExecutionOutput {
const ScopedShapedBuffer& Result() const { return result_; }
ScopedShapedBuffer* MutableResult() { return &result_; }
const se::OwningDeviceMemory& ShapeTable() const {
return output_shape_table_;
}
@ -169,12 +240,12 @@ class Executable {
// complete.
StatusOr<ExecutionOutput> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile);
virtual StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) = 0;
// Same as ExecuteOnStream(), but runs this executable on multiple
@ -208,7 +279,7 @@ class Executable {
StatusOr<ExecutionOutput> ExecuteAsyncOnStreamWrapper(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments);
std::vector<ExecutionInput> arguments);
const HloProfilePrinterData& hlo_profile_printer_data() const {
CHECK(hlo_profiling_enabled());

tensorflow/compiler/xla/service/gpu/BUILD Executable file → Normal file
View File

View File

@ -328,7 +328,7 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) {
XLA_SCOPED_LOGGING_TIMER(absl::StrCat("GpuExecutable::ExecuteAsyncOnStream(",
module().name(), ")"));
@ -367,7 +367,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
auto param_no = allocation.parameter_number();
se::DeviceMemoryBase buffer =
arguments[param_no]
.element(allocation.param_shape_index())
.Buffer(allocation.param_shape_index())
.AsDeviceMemoryBase();
// All top-level buffers and sub-buffers must have an explicit, non-null
@ -458,16 +458,15 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result));
std::vector<se::OwningDeviceMemory> buffers_to_free;
for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
auto maybe_owning_buffer = buffer.second.Release();
for (auto& argument : arguments) {
for (auto& index_buffer : *argument.MutableBuffers()) {
auto maybe_owning_buffer = index_buffer.second.Release();
if (maybe_owning_buffer) {
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
}
}
}
return ExecutionOutput(std::move(shaped_buffer), std::move(buffers_to_free),
{}, {});
return ExecutionOutput(std::move(shaped_buffer), std::move(buffers_to_free));
}
const InstructionValueSet& GpuExecutable::GetRootValueSet() const {

View File

@ -84,7 +84,7 @@ class GpuExecutable : public Executable {
// doesn't match the compute capability passed to this object's constructor.
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) override;
std::shared_ptr<const BufferAssignment> GetBufferAssignment() const {

View File

@ -520,7 +520,7 @@ llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b)),
b->CreateICmpEQ(
b->getInt32(0),
EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b)));
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b)));
}
bool AreFusedReductionOutputsConsistent(

View File

@ -2156,9 +2156,9 @@ void IrEmitterUnnested::EmitPrologueForReduction(
reduce_inst->shape().element_type(), module_);
llvm::Type* buffer_type = [&] {
if (reduction_info->IsRowReduction()) {
// Allocate __shared__ cache[num_partial_results][num_threads].
// Allocate __shared__ cache[num_partial_results][kWarpSize].
return llvm::ArrayType::get(
llvm::ArrayType::get(primitive_type, num_threads_x),
llvm::ArrayType::get(primitive_type, kWarpSize),
num_partial_results);
} else {
// Allocate __shared__

View File

@ -815,6 +815,30 @@ ENTRY %primitive_computation_svd.38 (constant_5: f32[3,29,29], fusion.3: pred[3]
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
}
TEST_F(GpuKernelTilingTest, RowReductionCorrectShmemUsage) {
const char *const kHloString = R"(
HloModule RowReduce
Sum {
x.1 = f32[] parameter(0)
y.1 = f32[] parameter(1)
ROOT add.1 = f32[] add(x.1, y.1)
}
ENTRY reduce.1 {
parameter = f32[1048576] parameter(0)
init_value = f32[] constant(0)
ROOT reduce = f32[] reduce(parameter, init_value), dimensions={0}, to_apply=Sum
}
)";
auto hlo_module = ParseAndReturnVerifiedModule(kHloString).ValueOrDie();
auto expected_ir = R"(
; CHECK: shared_cache_{{[0-9]*}} = private addrspace({{[0-9]*}}) global [1 x [32 x float]]
)";
CompileAndVerifyIr(std::move(hlo_module), expected_ir,
/*match_optimized_ir=*/true);
}
} // namespace
} // namespace gpu
} // namespace xla
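
The IR check above expects a [1 x [32 x float]] shared cache: the row-reduction cache is now sized by kWarpSize rather than by the block's thread count, presumably because at most one partial value per warp lands in the cache, so 32 slots cover any block of up to 1024 threads. A small arithmetic sketch under those assumptions (the thread count below is illustrative, not taken from the diff):

#include <cstdio>

int main() {
  constexpr int kWarpSize = 32;
  constexpr int num_threads_x = 1024;     // assumed block size
  constexpr int num_partial_results = 1;  // matches the expected IR above

  // Cache sized by thread count (before) versus by warp count (after).
  constexpr int before_floats = num_partial_results * num_threads_x;
  constexpr int after_floats = num_partial_results * kWarpSize;
  std::printf("shared cache: %d floats -> %d floats per block\n",
              before_floats, after_floats);
}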

View File

@ -412,6 +412,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
absl::flat_hash_map<const HloValue*, BufferInterval> buffer_intervals_;
Result result_;
BufferIntervalCompare buffer_interval_compare_;
BufferIntervalTree interval_tree_;
private:
int64 alignment_;
@ -420,8 +421,6 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
// Alloc or Free call.
int64 current_time_ = 0;
BufferIntervalTree interval_tree_;
// Returns all transitive colocated buffers of this buffer interval. I.e., If
// a buffer A is colocated with B and B is colocated with C, this function
// returns all three of them.

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
@ -109,8 +110,8 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
TEST_F(HloConstantFoldingTest, Concatenate) {
const struct TestConfig {
int concat_dimension;
absl::Span<const int64> dimensions;
absl::Span<const int64> concat_sizes;
std::vector<int64> dimensions;
std::vector<int64> concat_sizes;
} test_configs[] = {
{1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}},
{3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}},

tensorflow/compiler/xla/service/hlo_instruction.cc Executable file → Normal file
View File

tensorflow/compiler/xla/service/hlo_instructions.h Executable file → Normal file
View File

View File

@ -717,7 +717,10 @@ HloComputation* HloModule::GetComputationWithName(absl::string_view name) {
uint64 HloModule::Hash() const {
uint64 result = entry_computation_layout().Hash();
for (auto* computation : MakeComputationPostOrder()) {
// Use MakeComputationSortedByContent() instead of MakeComputationPostOrder()
// because naming may affect the order of MakeComputationPostOrder() but not
// MakeComputationSortedByContent().
for (auto* computation : MakeComputationSortedByContent()) {
for (auto* instruction : computation->MakeInstructionPostOrder()) {
result = tensorflow::Hash64Combine(result, instruction->Hash());
}

View File

@ -196,6 +196,10 @@ class HloModule {
// computation B, then A will appear after B in the sort.
std::vector<HloComputation*> MakeComputationPostOrder() const;
// Same as MakeComputationPostOrder() but sorting the computations by their
// contents.
std::vector<HloComputation*> MakeComputationSortedByContent() const;
// Gets the computations in this module which aren't for fusion nodes.
//
// Postcondition: All computations in the returned list have
@ -346,10 +350,6 @@ class HloModule {
std::unique_ptr<HloComputation> computation, bool is_entry,
bool uniquify_identifiers);
// Same as MakeComputationPostOrder() but sorting the computations by their
// contents.
std::vector<HloComputation*> MakeComputationSortedByContent() const;
string name_;
HloModuleConfig config_;
HloComputation* entry_computation_ = nullptr;

View File

@ -104,9 +104,8 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
VLOG(1) << "Run backend " << hlo_module->name();
// Typically you would visit the HLO graph, building up a compiled equivalent
// In this case we are using an HloEvaluator at execution time, so we don't
// need to compile anything
TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
DynamicDimensionInference::Run(hlo_module.get()));
auto evaluator = absl::make_unique<HloEvaluator>();
evaluator->set_use_fast_path(
@ -115,8 +114,9 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
// Create executable from only the Hlo module.
std::unique_ptr<Executable> executable =
absl::make_unique<InterpreterExecutable>(std::move(hlo_module),
std::move(evaluator));
absl::make_unique<InterpreterExecutable>(
std::move(hlo_module), std::move(evaluator),
std::move(dynamic_dimension_inference));
return std::move(executable);
}

View File

@ -39,16 +39,23 @@ namespace interpreter {
InterpreterExecutable::InterpreterExecutable(
std::unique_ptr<HloModule> hlo_module,
std::unique_ptr<HloEvaluator> evaluator)
std::unique_ptr<HloEvaluator> evaluator,
absl::optional<DynamicDimensionInference> dynamic_dymension_inference)
: Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr,
/*hlo_profile_index_map=*/nullptr),
evaluator_(std::move(evaluator)) {}
evaluator_(std::move(evaluator)),
dynamic_dimension_inference_(std::move(dynamic_dymension_inference)) {
if (dynamic_dimension_inference_.has_value()) {
evaluator_->set_dynamic_dimension_inference(
&dynamic_dimension_inference_.value());
}
}
InterpreterExecutable::~InterpreterExecutable() {}
StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) {
se::Stream* stream = run_options->stream();
se::StreamExecutor* executor = stream->parent();
@ -58,13 +65,14 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
// TransferManager methods below.
std::vector<ShapedBuffer> argument_buffers;
argument_buffers.reserve(arguments.size());
for (const ShapeTree<MaybeOwningDeviceMemory>& arg : arguments) {
argument_buffers.push_back(ShapedBuffer(arg.shape(), arg.shape(),
for (auto& argument : arguments) {
const ShapeTree<MaybeOwningDeviceMemory>& buffers = argument.Buffers();
argument_buffers.push_back(ShapedBuffer(buffers.shape(), buffers.shape(),
/*platform=*/nullptr,
/*device_ordinal=*/0));
auto in_it = arg.begin();
auto in_it = buffers.begin();
auto out_it = argument_buffers.back().buffers().begin();
for (; in_it != arg.end(); ++in_it, ++out_it) {
for (; in_it != buffers.end(); ++in_it, ++out_it) {
out_it->second = in_it->second.AsDeviceMemoryBase();
}
}
@ -120,12 +128,13 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
}
// Transform the result literal back into a ShapedBuffer.
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result_buffers,
transfer_manager->AllocateScopedShapedBuffer(
result_literal.shape(), run_options->allocator(),
executor->device_ordinal()));
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
run_options->stream(), result_literal, result));
run_options->stream(), result_literal, result_buffers));
ExecutionOutput result(std::move(result_buffers));
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@ -134,17 +143,15 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
const double nanoseconds = (end_micros - start_micros) * 1000.0;
profile->set_compute_time_ns(std::max(nanoseconds, 1.0));
}
std::vector<se::OwningDeviceMemory> buffers_to_free;
for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
auto maybe_owning_buffer = buffer.second.Release();
for (auto& argument : arguments) {
for (auto& index_buffer : *argument.MutableBuffers()) {
auto maybe_owning_buffer = index_buffer.second.Release();
if (maybe_owning_buffer) {
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
result.AddToBeReleased(std::move(*maybe_owning_buffer));
}
}
}
return ExecutionOutput(std::move(result), std::move(buffers_to_free), {}, {});
return std::move(result);
}
/*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) {

View File

@ -42,13 +42,15 @@ namespace interpreter {
// buffer allocation. Refer to interpreter/README.md for more.
class InterpreterExecutable : public Executable {
public:
InterpreterExecutable(std::unique_ptr<HloModule> hlo_module,
std::unique_ptr<HloEvaluator> evaluator);
InterpreterExecutable(
std::unique_ptr<HloModule> hlo_module,
std::unique_ptr<HloEvaluator> evaluator,
absl::optional<DynamicDimensionInference> dynamic_dymension_inference);
~InterpreterExecutable() override;
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) override
LOCKS_EXCLUDED(evaluator_lock_);
@ -60,6 +62,7 @@ class InterpreterExecutable : public Executable {
mutable tensorflow::mutex evaluator_lock_;
private:
absl::optional<DynamicDimensionInference> dynamic_dimension_inference_;
TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable);
};
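
The interpreter executable now accepts an optional DynamicDimensionInference and, when one is supplied, points its evaluator at the stored copy. A minimal sketch of that optional-dependency pattern with hypothetical stand-in types (DimInference, Evaluator, InterpreterLikeExecutable):

#include <iostream>
#include <memory>
#include <optional>
#include <utility>

struct DimInference { int tracked_dims = 0; };  // stand-in

class Evaluator {
 public:
  void set_dynamic_dimension_inference(const DimInference* d) { dims_ = d; }
  bool has_inference() const { return dims_ != nullptr; }
 private:
  const DimInference* dims_ = nullptr;
};

// Stores the optional by value so the pointer handed to the evaluator stays
// valid for the executable's lifetime.
class InterpreterLikeExecutable {
 public:
  InterpreterLikeExecutable(std::unique_ptr<Evaluator> evaluator,
                            std::optional<DimInference> inference)
      : evaluator_(std::move(evaluator)), inference_(std::move(inference)) {
    if (inference_.has_value()) {
      evaluator_->set_dynamic_dimension_inference(&inference_.value());
    }
  }
  const Evaluator& evaluator() const { return *evaluator_; }

 private:
  std::unique_ptr<Evaluator> evaluator_;
  std::optional<DimInference> inference_;
};

int main() {
  InterpreterLikeExecutable exe(std::make_unique<Evaluator>(), DimInference{3});
  std::cout << exe.evaluator().has_inference() << "\n";  // prints 1
}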

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
namespace xla {
namespace {
@ -380,6 +382,13 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
continue;
}
if (!ConsumeFuel("memory_space_assignment", [&] {
return absl::StrCat("Ran out of fuel at buffer: ",
colocated_intervals[0]->buffer->ToShortString());
})) {
continue;
}
const HloComputation* defining_computation =
colocated_intervals[0]->buffer->defining_instruction()->parent();
MemorySpaceAssignment::Allocation* aliased_allocation = nullptr;
@ -458,8 +467,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
// If the allocation finding failed (e.g., due to running out of
// asynchronous copies), then fall back to allocating the buffer
// entirely in the default memory.
pending_chunks_.clear();
pending_async_copies_.clear();
UncommitPendingChunks();
allocation_sequence->clear();
break;
}
@ -478,7 +486,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
}
}
CommitPendingChunks();
pending_chunks_.clear();
pending_async_copies_.clear();
}
if (VLOG_IS_ON(3)) {
@ -510,6 +519,12 @@ void AsynchronousCopyOrdering::AddCopy(const AsynchronousCopy& copy) {
it_and_inserted.first->start_time == copy.start_time);
}
void AsynchronousCopyOrdering::RemoveCopy(const AsynchronousCopy& copy) {
auto copy_it = ranges_.find(copy);
CHECK(copy_it != ranges_.end());
ranges_.erase(copy_it);
}
bool AsynchronousCopyOrdering::ViolatesOrdering(int64 start_time,
int64 end_time) const {
// We allow identical start and end times. It is enough to check for just the
@ -620,32 +635,31 @@ bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory(
return false;
}
void AlternateMemoryBestFitHeap::CommitPendingChunks() {
void AlternateMemoryBestFitHeap::UncommitPendingChunks() {
for (auto interval_and_chunk : pending_chunks_) {
VLOG(3) << "Committing chunk: " << interval_and_chunk.first.start << "-"
<< interval_and_chunk.first.end << " : ["
<< interval_and_chunk.second.chunk.offset << ", "
<< interval_and_chunk.second.chunk.size << "]";
CommitChunk(interval_and_chunk.first, interval_and_chunk.second);
const BufferInterval& interval = interval_and_chunk.first;
const Chunk& chunk = interval_and_chunk.second.chunk;
interval_tree_.Remove(interval.start, interval.end, chunk);
}
for (const auto& interval : pending_async_copies_) {
async_copy_interval_tree_.Remove(interval.start_time, interval.end_time,
kDummyChunk);
if (interval.destination == MemorySpace::kAlternate) {
async_copy_ordering_.RemoveCopy(interval);
}
}
pending_chunks_.clear();
// Also add the pending async copies to the interval tree.
for (const auto& interval : pending_async_copies_) {
if (options_.max_outstanding_async_copies >= 0) {
async_copy_interval_tree_.Add(interval.start_time, interval.end_time,
kDummyChunk);
}
if (interval.destination == MemorySpace::kAlternate) {
async_copy_ordering_.AddCopy(interval);
}
}
pending_async_copies_.clear();
}
void AlternateMemoryBestFitHeap::AddToPendingChunks(
const BufferInterval& buffer_interval,
const ChunkCandidate& chunk_candidate) {
VLOG(3) << "Committing chunk: " << buffer_interval.start << "-"
<< buffer_interval.end << " : [" << chunk_candidate.chunk.offset
<< ", " << chunk_candidate.chunk.size << "]";
pending_chunks_.emplace_back(buffer_interval, chunk_candidate);
CommitChunk(buffer_interval, chunk_candidate);
}
bool AlternateMemoryBestFitHeap::RequiredInDefaultMemory(const HloValue* buffer,
@ -772,6 +786,10 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy(
// Register the additional async copy with the interval tree to keep track of
// the limit at any given time.
pending_async_copies_.push_back({start_time, end_time, memory_space});
async_copy_interval_tree_.Add(start_time, end_time, kDummyChunk);
if (memory_space == MemorySpaceAssignment::MemorySpace::kAlternate) {
async_copy_ordering_.AddCopy(pending_async_copies_.back());
}
}
bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
@ -780,17 +798,11 @@ bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
return false;
}
// Count both the asynchronous copies in the interval tree as well as the
// pending asynchronous copies belonging to this buffer.
// Count the asynchronous copies in the interval tree for the given interval.
int64 num_async_copies =
async_copy_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
.size();
for (const auto& interval : pending_async_copies_) {
if (interval.start_time > start_time && interval.end_time < end_time) {
num_async_copies++;
}
}
// Add one because we are checking if adding an additional asynchronous copy
// would violate the limit.
return num_async_copies + 1 > options_.max_outstanding_async_copies;
@ -798,19 +810,7 @@ bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
bool AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(
int64 start_time, int64 end_time) const {
if (async_copy_ordering_.ViolatesOrdering(start_time, end_time)) {
return true;
}
// Also check pending async copies.
for (const auto& async_copy : pending_async_copies_) {
if (async_copy.destination == MemorySpace::kAlternate &&
async_copy.start_time <= end_time &&
start_time <= async_copy.end_time) {
return true;
}
}
return false;
return async_copy_ordering_.ViolatesOrdering(start_time, end_time);
}
bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy(
@ -1048,7 +1048,7 @@ bool AlternateMemoryBestFitHeap::Prefetch(
ChunkCandidate chunk_candidate = FindChunkCandidate(alternate_mem_interval);
// Check if the new heap size fits within limits.
if (chunk_candidate.heap_size < available_heap_size()) {
if (chunk_candidate.heap_size <= available_heap_size()) {
VLOG(3) << "Move the buffer to alternate memory at "
<< alternate_mem_interval.start
<< ". Offset = " << chunk_candidate.chunk.offset

View File

@ -580,6 +580,9 @@ class AsynchronousCopyOrdering {
// Adds an asynchronous copy.
void AddCopy(const AsynchronousCopy& copy);
// Removes an asynchronous copy. CHECKs that it is removed.
void RemoveCopy(const AsynchronousCopy& copy);
// Returns true if the addition of an asynchronous copy in the given time
// interval would violate the asynchronous copy ordering. E.g., consider the
// following scenario:
@ -735,11 +738,15 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
int64 end_time, int64 copy_done_schedule_before_time,
MemorySpaceAssignment::AllocationSequence* allocations);
// These methods are used for delaying committing the chunk candidate until
// the entire live range of the buffer has been considered.
// This method is used for committing the chunk candidate but adding it to
// pending_chunks_ so that we can "uncommit" them in case we need to roll back
// this allocation sequence.
void AddToPendingChunks(const BufferInterval& buffer_interval,
const ChunkCandidate& chunk_candidate);
void CommitPendingChunks();
// If we need to remove the allocations for this allocation sequence, this
// removes pending chunks and asynchronous copies in the respective pending
// buffers from the interval trees.
void UncommitPendingChunks();
// Returns the available heap size in the alternate memory.
int64 available_heap_size() const {

View File

@ -2827,6 +2827,100 @@ TEST_P(MemorySpaceAssignmentTest,
}
}
TEST_P(MemorySpaceAssignmentTest, PendingChunkMemoryCorruptionBug) {
// Tests a memory corruption bug where the allocated chunk overlaps with a
// pending chunk. To test this, we provide a new buffer interval compare where
// we prioritize the allocation of sine, cosine, and tanh to create the
// situation:
//
// Max memory
// -------------------------------------------
// +------------+
// | b |
// +------------+
// +-------+
// | |
// | |
// | a |
// | | +------------+
// | | | n |
// +-------+ +------------+
// -------------------------------------------
// Min memory time ->
//
//
// Then allocating for buffer d, we have these two prefetch buffers
// overlapping:
//
// Max memory
// -------------------------------------------
// +------------+ +----------+
// | b | | prefetch |
// +------------+ | for o |
// +-------+ +---------+ |
// | | | | | |
// | | | | | |
// | a | | +----|-----+
// | | | prefetch| +------------+
// | | | for m | | n |
// +-------+ +---------+ +------------+
// -------------------------------------------
// Min memory time ->
//
absl::string_view hlo_string = R"(
HloModule bug, is_scheduled=true
ENTRY %Entry {
%param0 = f32[8,3] parameter(0)
%param1 = f32[2,4] parameter(1)
%a = f32[8,3] sine(%param0)
%b = f32[2,4] cosine(%param1)
%d = f32[8,3] tanh(%a)
%c = f32[8,3] negate(%a)
%e = f32[2,4] negate(%b)
%f = f32[2,4] negate(%e)
%g = f32[2,4] negate(%f)
%h = f32[2,4] negate(%g)
%i = f32[2,4] negate(%h)
%j = f32[2,4] negate(%i)
%k = f32[2,4] negate(%j)
%l = f32[2,4] negate(%k)
%m = f32[8,3] negate(%d)
%n = f32[2,4] sine(%l)
%o = f32[8,3] negate(%d)
%p = f32[2,4] negate(%n)
%q = f32[8,3] negate(%m)
ROOT %tuple = (f32[2,4], f32[8,3], f32[8,3]) tuple(%p, %q, %o)
}
)";
MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
[](const MemorySpaceAssignment::BufferInterval& a,
const MemorySpaceAssignment::BufferInterval& b) {
auto get_opcode_priority = [](const HloOpcode& opcode) {
switch (opcode) {
case HloOpcode::kSin:
return 0;
case HloOpcode::kCos:
return 1;
case HloOpcode::kTanh:
return 2;
default:
return 3;
}
};
return get_opcode_priority(a.buffer->defining_instruction()->opcode()) <
get_opcode_priority(b.buffer->defining_instruction()->opcode());
};
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
buffer_interval_compare, &prefetch_interval_picker);
}
TEST_P(MemorySpaceAssignmentTest, Determinism) {
// Run memory space assignment a few times to make sure every time it compiles
// to the same thing.

View File

@ -2513,6 +2513,18 @@ xla_test(
],
)
xla_test(
name = "get_dimension_size_test",
srcs = ["get_dimension_size_test.cc"],
deps = [
":hlo_test_base",
":test_macros_header",
":xla_internal_test_main", # fixdeps: keep
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:test",
],
)
xla_test(
name = "triangular_solve_test",
srcs = ["triangular_solve_test.cc"],

View File

@ -96,8 +96,8 @@ class BufferDonationTest : public HloTestBase {
memory_allocator.get());
});
std::vector<ShapeTree<MaybeOwningDeviceMemory>> args;
args.emplace_back(std::move(owned_buffers));
std::vector<ExecutionInput> args;
args.emplace_back(ExecutionInput(std::move(owned_buffers)));
TF_ASSERT_OK_AND_ASSIGN(
ExecutionOutput output,

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
namespace xla {
namespace {
class GetDimensionSizeTest : public HloTestBase {};
// Test that the interpreter can correctly compute get_dimension_size.
TEST_F(GetDimensionSizeTest, DoIt) {
const char* const kModuleStr = R"(
HloModule a_inference_call_110__.55
ENTRY %a_inference_call_110__.55 (arg0.1: f32[1,8], arg1.2: f32[8], arg2.3: f32[8]) -> s32[] {
%constant.37 = f32[] constant(1e-12)
%broadcast.38 = f32[1,1]{1,0} broadcast(f32[] %constant.37), dimensions={}
%arg0.1 = f32[1,8]{1,0} parameter(0), parameter_replication={false}
%reshape.4 = f32[1,8]{1,0} reshape(f32[1,8]{1,0} %arg0.1)
%convert.5 = f32[1,8]{1,0} convert(f32[1,8]{1,0} %reshape.4)
%constant.6 = f32[] constant(0)
%convert.7 = f32[] convert(f32[] %constant.6)
ROOT %get-dimension-size.13 = s32[] get-dimension-size(f32[1,8]{1,0} %convert.5), dimensions={1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr));
EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01, 0.01}));
}
} // anonymous namespace
} // namespace xla

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
@ -39,8 +40,8 @@ static std::array<bool, 1> use_bfloat16_params{false};
#endif
struct ReverseSpec {
absl::Span<const int64> input_dims;
absl::Span<const int64> reversal;
std::vector<int64> input_dims;
std::vector<int64> reversal;
bool use_bfloat16;
string ToTestCaseName() const {

View File

@ -72,6 +72,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_proto_cc",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",

Some files were not shown because too many files have changed in this diff