Merge branch 'master' into interface_16x8

commit eaffdc0340

.bazelrc (2 changed lines)
@@ -137,6 +137,7 @@ build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain
# environment variable "TF_MKL_ROOT" every time before build.
build:mkl --define=build_with_mkl=true --define=enable_mkl=true
build:mkl --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl --define=build_with_mkl_dnn_v1_only=true
build:mkl -c opt

# This config option is used to enable MKL-DNN open source library only,

@@ -238,7 +239,6 @@ build:c++1z --config=c++17

# Enable using platform specific build settings
build --enable_platform_specific_config
build --config=xla

# Suppress C++ compiler warnings, otherwise build logs become 10s of MBs.
build:linux --copt=-w
.github/ISSUE_TEMPLATE/30-feature-request.md (3 changed lines)

@@ -1,5 +1,6 @@
---
name: Feature Request about: Use this template for raising a feature request
name: Feature Request
about: Use this template for raising a feature request
labels: 'type:feature'

---
README.md (16 changed lines)

@@ -130,18 +130,20 @@ Build Type | Status
## Resources

* [TensorFlow.org](https://www.tensorflow.org)
* [TensorFlow tutorials](https://www.tensorflow.org/tutorials/)
* [TensorFlow official models](https://github.com/tensorflow/models/tree/master/official)
* [TensorFlow examples](https://github.com/tensorflow/examples)
* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
* [TensorFlow Official Models](https://github.com/tensorflow/models/tree/master/official)
* [TensorFlow Examples](https://github.com/tensorflow/examples)
* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice)
* [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment)
* [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187)
* [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190)
* [TensorFlow blog](https://blog.tensorflow.org)
* [TensorFlow Blog](https://blog.tensorflow.org)
* [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml)
* [TensorFlow Twitter](https://twitter.com/tensorflow)
* [TensorFlow YouTube](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
* [TensorFlow roadmap](https://www.tensorflow.org/community/roadmap)
* [TensorFlow white papers](https://www.tensorflow.org/about/bib)
* [TensorBoard visualization toolkit](https://github.com/tensorflow/tensorboard)
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
* [TensorBoard Visualization Toolkit](https://github.com/tensorflow/tensorboard)

Learn more about the
[TensorFlow community](https://www.tensorflow.org/community) and how to
@@ -1390,7 +1390,7 @@ def main():
  else:
    environ_cp['TF_CONFIGURE_IOS'] = '0'

  if environ_cp.get('TF_ENABLE_XLA', 1):
  if environ_cp.get('TF_ENABLE_XLA', '1') == '1':
    write_to_bazelrc('build --config=xla')

  set_action_env_var(
@@ -1213,10 +1213,10 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
  tensorflow::TensorHandle* ret_handle;
  if (custom_device == nullptr) {
    status->status = tensorflow::TensorHandle::CreateLocalHandle(
        t, device, context, &ret_handle);
        std::move(t), device, device, context, &ret_handle);
  } else {
    status->status = tensorflow::TensorHandle::CreateLocalHandle(
        t, custom_device, context, &ret_handle);
        std::move(t), custom_device, context, &ret_handle);
  }
  if (!status->status.ok()) {
    return nullptr;
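For readers of the C API change above: the `CreateLocalHandle` overload used here now takes the tensor by value (so it is moved rather than copied) and an explicit op device in the non-custom-device branch. A minimal sketch of the new call shape, assuming `t`, `device`, `context`, `status`, and `ret_handle` are in scope as in the surrounding function; the parameter-name comments are assumptions for illustration only:

```cpp
// Sketch only: mirrors the updated call above. The second `device` argument is
// assumed to be the op device; for a plain local handle the same pointer is
// passed for both.
status->status = tensorflow::TensorHandle::CreateLocalHandle(
    std::move(t), /*d=*/device, /*op_device=*/device, context, &ret_handle);
if (!status->status.ok()) return nullptr;
```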
@@ -307,14 +307,12 @@ cc_library(
        "transforms/runtime_type_verify.cc",
        "transforms/split_merged_operands.cc",
        "transforms/trim_functions_tf.cc",
        "transforms/unroll_batch_matmul.cc",
        "transforms/while_loop_outline.cc",
    ],
    hdrs = [
        "ir/tfl_ops_interface.h.inc",
        "transforms/dilated_conv.h",
        "transforms/passes.h",
        "transforms/unroll_batch_matmul.h",
    ],
    deps = [
        ":common",

@@ -327,6 +325,7 @@ cc_library(
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:convert_tensor",
        "//tensorflow/compiler/mlir/tensorflow:mangling_util",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:statusor",
@@ -534,7 +534,7 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> {
  let summary = "ArgMin operator";

  let description = [{
    Returns the index with the smallest value across dimensions of a tensor."
    Returns the index with the smallest value across dimensions of a tensor.
    a = [1, 10, 26.9, 2.8, 166.32, 62.3]
    b = tf.math.argmin(input = a)
    c = tf.keras.backend.eval(b)
@@ -3181,19 +3181,19 @@ def TFL_LSTMOp :
    Long short-term memory unit (LSTM) recurrent network layer.
    The default non-peephole implementation is based on:
    http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
    S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation,
    S. Hochreiter and J. Schmidhuber. 'Long Short-Term Memory'. Neural Computation,
    9(8):1735-1780, 1997.
    The peephole implementation is based on:
    https://research.google.com/pubs/archive/43905.pdf
    Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory
    recurrent neural network architectures for large scale acoustic modeling.
    Hasim Sak, Andrew Senior, and Francoise Beaufays. 'Long short-term memory
    recurrent neural network architectures for large scale acoustic modeling.'
    INTERSPEECH, 2014.
    The coupling of input and forget gate (CIFG) is based on:
    http://arxiv.org/pdf/1503.04069.pdf
    Greff et al. "LSTM: A Search Space Odyssey"
    Greff et al. 'LSTM: A Search Space Odyssey'
    The layer normalization is based on:
    https://arxiv.org/pdf/1607.06450.pdf
    Ba et al. “Layer Normalization”
    Ba et al. 'Layer Normalization'
  }];

  let arguments = (
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
@@ -197,60 +198,60 @@ Status RegisterCustomBuiltinOps(const std::vector<string> extra_tf_opdefs) {
  return Status::OK();
}

} // namespace

Status RegisterAllCustomOps(const toco::TocoFlags& toco_flags) {
  // Register any custom OpDefs.
  std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
                                      toco_flags.custom_opdefs().end());
  extra_tf_opdefs.push_back(kDetectionPostProcessOp);
  extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
  extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
  return RegisterCustomBuiltinOps(extra_tf_opdefs);
}

Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
                                         const toco::TocoFlags& toco_flags,
                                         const GraphDebugInfo& debug_info,
                                         const GraphDef& input,
                                         string* result) {
  mlir::MLIRContext context;
  GraphImportConfig specs;
  mlir::TFL::QuantizationSpecs quant_specs;

  // Parse input arrays.
  std::vector<string> node_names;
  std::vector<string> node_dtypes;
  std::vector<std::vector<int>> node_shapes;
  std::vector<double> node_mins;
  std::vector<double> node_maxs;
  quant_specs.inference_input_type =
Status PopulateQuantizationSpecs(const toco::ModelFlags& model_flags,
                                 const toco::TocoFlags& toco_flags,
                                 mlir::TFL::QuantizationSpecs* quant_specs,
                                 std::vector<string>* node_names,
                                 std::vector<string>* node_dtypes,
                                 std::vector<std::vector<int>>* node_shapes,
                                 std::vector<double>* node_mins,
                                 std::vector<double>* node_maxs) {
  quant_specs->inference_input_type =
      ConvertIODataTypeToDataType(toco_flags.inference_input_type());
  tensorflow::DataType inference_type =
      ConvertIODataTypeToDataType(toco_flags.inference_type());
  // Use non-float flag `inference_input_type` to override the `inference_type`
  // because we have to apply quantization to satisfy that.
  if (quant_specs.inference_input_type != tensorflow::DT_FLOAT) {
    inference_type = quant_specs.inference_input_type;
  if (quant_specs->inference_input_type != tensorflow::DT_FLOAT) {
    inference_type = quant_specs->inference_input_type;
  }

  for (auto& flag : model_flags.input_arrays()) {
    node_names.push_back(flag.name());
    node_names->push_back(flag.name());
    // TOCO doesn't required `data_type` to be filled for every input.
    // If it's not filled, make it an empty string so the importer will use
    // the data type in the NodeDef.
    auto toco_data_type = flag.data_type();
    if (toco_data_type == ::toco::IODataType::IO_DATA_TYPE_UNKNOWN) {
      node_dtypes.push_back("");
      node_dtypes->push_back("");
    } else {
      node_dtypes.push_back(
      node_dtypes->push_back(
          DataType_Name(ConvertIODataTypeToDataType(toco_data_type)));
    }
    node_shapes.push_back(std::vector<int>(flag.shape().dims().begin(),
                                           flag.shape().dims().end()));
    node_shapes->push_back(std::vector<int>(flag.shape().dims().begin(),
                                            flag.shape().dims().end()));
    // Currently, only UINT8 and INT8 require inputs stats
    if (inference_type == DT_QINT8 || inference_type == DT_QUINT8) {
      TF_ASSIGN_OR_RETURN(
          auto min_max, InputStatsToMinMax(flag.mean_value(), flag.std_value(),
                                           inference_type));
      node_mins.push_back(min_max.first);
      node_maxs.push_back(min_max.second);
      node_mins->push_back(min_max.first);
      node_maxs->push_back(min_max.second);
    }
  }
  TF_RETURN_IF_ERROR(tensorflow::ParseInputArrayInfo(
      node_names, node_dtypes, node_shapes, &specs.inputs));
  if (mlir::TFL::GetInputNodeQuantSpecs(node_names, node_mins, node_maxs,
                                        inference_type, &quant_specs)) {

  if (mlir::TFL::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs,
                                        inference_type, quant_specs)) {
    return errors::InvalidArgument("Failed to get input quant spec.");
  }
@@ -258,49 +259,34 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
  // quantization is enabled, `inference_type` and `inference_input_type` are
  // not used by MLIR passes.
  if (toco_flags.post_training_quantize()) {
    quant_specs.weight_quantization = true;
    quant_specs->weight_quantization = true;
    if (toco_flags.quantize_to_float16()) {
      quant_specs.inference_type = tensorflow::DT_HALF;
      quant_specs.inference_input_type = tensorflow::DT_HALF;
      quant_specs->inference_type = tensorflow::DT_HALF;
      quant_specs->inference_input_type = tensorflow::DT_HALF;
    } else {
      quant_specs.inference_type = tensorflow::DT_QINT8;
      quant_specs.inference_input_type = tensorflow::DT_QINT8;
      quant_specs->inference_type = tensorflow::DT_QINT8;
      quant_specs->inference_input_type = tensorflow::DT_QINT8;
    }
  }

  // Parse output arrays.
  std::vector<string> output_arrays(model_flags.output_arrays().begin(),
                                    model_flags.output_arrays().end());
  TF_RETURN_IF_ERROR(
      tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));

  // Other flags.
  if (toco_flags.has_default_ranges_min()) {
    quant_specs.default_ranges.first = toco_flags.default_ranges_min();
    quant_specs->default_ranges.first = toco_flags.default_ranges_min();
  }
  if (toco_flags.has_default_ranges_max()) {
    quant_specs.default_ranges.second = toco_flags.default_ranges_max();
    quant_specs->default_ranges.second = toco_flags.default_ranges_max();
  }

  return ::tensorflow::Status::OK();
}

Status ConvertMLIRToTFLiteFlatBuffer(const toco::TocoFlags& toco_flags,
                                     mlir::OwningModuleRef module,
                                     mlir::TFL::QuantizationSpecs quant_specs,
                                     string* result) {
  bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
  bool emit_select_tf_ops = toco_flags.enable_select_tf_ops();
  bool emit_custom_ops = toco_flags.allow_custom_ops();
  specs.prune_unused_nodes = true;
  specs.convert_legacy_fed_inputs = true;
  specs.graph_as_function = false;
  specs.upgrade_legacy = true;
  WarningUnusedFlags(model_flags, toco_flags);

  // Register any custom OpDefs.
  std::vector<string> extra_tf_opdefs(toco_flags.custom_opdefs().begin(),
                                      toco_flags.custom_opdefs().end());
  extra_tf_opdefs.push_back(kDetectionPostProcessOp);
  extra_tf_opdefs.push_back(kUnidirectionalSequenceLstmOp);
  extra_tf_opdefs.push_back(kUnidirectionalSequenceRnnOp);
  TF_RETURN_IF_ERROR(RegisterCustomBuiltinOps(extra_tf_opdefs));

  TF_ASSIGN_OR_RETURN(
      auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));

  if (toco_flags.has_dump_graphviz_dir()) {
    TF_RETURN_IF_ERROR(DumpOpGraphToFile(

@@ -334,4 +320,52 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
  return status;
}

} // namespace

Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
                                         const toco::TocoFlags& toco_flags,
                                         const GraphDebugInfo& debug_info,
                                         const GraphDef& input,
                                         string* result) {
  mlir::MLIRContext context;
  GraphImportConfig specs;
  mlir::TFL::QuantizationSpecs quant_specs;

  // Parse input arrays.
  std::vector<string> node_names;
  std::vector<string> node_dtypes;
  std::vector<std::vector<int>> node_shapes;
  std::vector<double> node_mins;
  std::vector<double> node_maxs;

  // Populate quantization specs.
  TF_RETURN_IF_ERROR(PopulateQuantizationSpecs(
      model_flags, toco_flags, &quant_specs, &node_names, &node_dtypes,
      &node_shapes, &node_mins, &node_maxs));

  TF_RETURN_IF_ERROR(tensorflow::ParseInputArrayInfo(
      node_names, node_dtypes, node_shapes, &specs.inputs));

  // Parse output arrays.
  std::vector<string> output_arrays(model_flags.output_arrays().begin(),
                                    model_flags.output_arrays().end());
  TF_RETURN_IF_ERROR(
      tensorflow::ParseOutputArrayInfo(output_arrays, &specs.outputs));

  specs.prune_unused_nodes = true;
  specs.convert_legacy_fed_inputs = true;
  specs.graph_as_function = false;
  specs.upgrade_legacy = true;
  WarningUnusedFlags(model_flags, toco_flags);

  // Register all custom ops, including user-specified custom ops.
  TF_RETURN_IF_ERROR(RegisterAllCustomOps(toco_flags));

  TF_ASSIGN_OR_RETURN(
      auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));

  return ConvertMLIRToTFLiteFlatBuffer(toco_flags, std::move(module),
                                       quant_specs, result);
}

} // namespace tensorflow
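Taken together, the refactor above splits the old monolithic conversion into `PopulateQuantizationSpecs`, `RegisterAllCustomOps`, and `ConvertMLIRToTFLiteFlatBuffer`, with `ConvertGraphDefToTFLiteFlatBuffer` reduced to orchestration. A hedged sketch of a caller, using only names visible in this diff; how the `GraphDef` and debug info are produced, and which header declares the entry point, are assumptions for illustration:

```cpp
// Hypothetical caller of the refactored entry point (illustrative only).
tensorflow::Status ConvertForExample(const tensorflow::GraphDef& graph_def,
                                     const tensorflow::GraphDebugInfo& debug_info,
                                     const toco::ModelFlags& model_flags,
                                     const toco::TocoFlags& toco_flags,
                                     std::string* flatbuffer) {
  // Internally this now: populates quantization specs, parses input/output
  // arrays, registers custom ops via RegisterAllCustomOps, imports the
  // GraphDef to MLIR, and hands off to ConvertMLIRToTFLiteFlatBuffer.
  return tensorflow::ConvertGraphDefToTFLiteFlatBuffer(
      model_flags, toco_flags, debug_info, graph_def, flatbuffer);
}
```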
@ -180,13 +180,12 @@ func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor
|
||||
// CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
|
||||
// CHECK: [[VAL_19:%.*]] = constant unit
|
||||
// CHECK: [[VAL_20:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) ( {
|
||||
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
|
||||
// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
|
||||
// CHECK: [[VAL_21:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_22:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_23:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_24:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
|
||||
// CHECK: return [[VAL_21]], [[VAL_25:%.*]], [[VAL_22]], [[VAL_23]], [[VAL_24]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
// CHECK: return [[VAL_21]], [[VAL_20]], [[VAL_22]], [[VAL_23]], [[VAL_24]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
// CHECK: }
|
||||
}
|
||||
|
||||
@ -221,15 +220,14 @@ func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: te
|
||||
// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
|
||||
// CHECK: [[VAL_21:%.*]] = constant unit
|
||||
// CHECK: [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( {
|
||||
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
|
||||
// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
|
||||
// CHECK: [[VAL_23:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
|
||||
// CHECK: [[VAL_24:%.*]] = "tf.Transpose"([[VAL_25:%.*]], [[VAL_23]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
|
||||
// CHECK: [[VAL_24:%.*]] = "tf.Transpose"([[VAL_22]], [[VAL_23]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
|
||||
// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_29:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
|
||||
// CHECK: return [[VAL_26]], [[VAL_24]], [[VAL_27]], [[VAL_28]], [[VAL_29]] : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
// CHECK: [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
|
||||
// CHECK: return [[VAL_25]], [[VAL_24]], [[VAL_26]], [[VAL_27]], [[VAL_28]] : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
// CHECK: }
|
||||
}
|
||||
|
||||
@ -264,13 +262,12 @@ func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor<?x8x8xf32>,
|
||||
// CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
|
||||
// CHECK: [[VAL_21:%.*]] = constant unit
|
||||
// CHECK: [[VAL_22:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) ( {
|
||||
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
|
||||
// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
|
||||
// CHECK: [[VAL_23:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_24:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
|
||||
// CHECK: return [[VAL_23]], [[VAL_27:%.*]], [[VAL_24]], [[VAL_25]], [[VAL_26]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
// CHECK: return [[VAL_23]], [[VAL_22]], [[VAL_24]], [[VAL_25]], [[VAL_26]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
// CHECK: }
|
||||
|
||||
}
|
||||
@ -308,16 +305,92 @@ func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf3
|
||||
// CHECK: [[VAL_21:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_22:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_20]], [[VAL_21]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
|
||||
// CHECK: [[VAL_23:%.*]] = constant unit
|
||||
// CHECK: [[VAL_24:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_16]]#0, [[VAL_16]]#1, [[VAL_16]]#2, [[VAL_16]]#3, [[VAL_19]]#0, [[VAL_19]]#1, [[VAL_19]]#2, [[VAL_19]]#3, [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_22]]#0, [[VAL_22]]#1, [[VAL_22]]#2, [[VAL_22]]#3, [[VAL_23]], [[VAL_23]], [[VAL_1]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) ( {
|
||||
// CHECK: }) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
|
||||
// CHECK: [[VAL_24:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_16]]#0, [[VAL_16]]#1, [[VAL_16]]#2, [[VAL_16]]#3, [[VAL_19]]#0, [[VAL_19]]#1, [[VAL_19]]#2, [[VAL_19]]#3, [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_22]]#0, [[VAL_22]]#1, [[VAL_22]]#2, [[VAL_22]]#3, [[VAL_23]], [[VAL_23]], [[VAL_1]], [[VAL_2]], [[VAL_23]], [[VAL_23]], [[VAL_23]], [[VAL_23]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
|
||||
// CHECK: [[VAL_25:%.*]] = constant dense<[1, 0, 2]> : tensor<3xi64>
|
||||
// CHECK: [[VAL_26:%.*]] = "tf.Transpose"([[VAL_27:%.*]], [[VAL_25]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
|
||||
// CHECK: [[VAL_26:%.*]] = "tf.Transpose"([[VAL_24]], [[VAL_25]]) : (tensor<8x8x10xf32>, tensor<3xi64>) -> tensor<8x8x10xf32>
|
||||
// CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_29:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_30:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_31:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
|
||||
// CHECK: return [[VAL_28]], [[VAL_26]], [[VAL_29]], [[VAL_30]], [[VAL_31]] : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
// CHECK: [[VAL_30:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
|
||||
// CHECK: return [[VAL_27]], [[VAL_26]], [[VAL_28]], [[VAL_29]], [[VAL_30]] : tensor<?x10xf32>, tensor<8x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
// CHECK: }
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
func @inference_can_fuse(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) {
|
||||
%0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = f32, value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
%1:5 = "tf.PartitionedCall"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], _output_shapes = ["tfshape$dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$"], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\008\01", device = "", executor_type = "", f = @inference_standard_lstm_time_major_can_fuse} : (tensor<?x8x8xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<8x40xf32>, tensor<10x40xf32>, tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>)
|
||||
%2 = "tf.Add"(%0, %1#1) : (tensor<f32>, tensor<?x8x10xf32>) -> tensor<?x8x10xf32>
|
||||
return
|
||||
}
|
||||
|
||||
func @inference_standard_lstm_time_major_can_fuse(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
|
||||
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
|
||||
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
|
||||
%3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
|
||||
%4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
|
||||
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
|
||||
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @inference_standard_lstm_time_major_can_fuse([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
|
||||
// CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||
// CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi64>) -> tensor<40x8xf32>
|
||||
// CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi64>
|
||||
// CHECK: [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi64>) -> tensor<40x10xf32>
|
||||
// CHECK: [[VAL_10:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: [[VAL_11:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
|
||||
// CHECK: [[VAL_13:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: [[VAL_14:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
|
||||
// CHECK: [[VAL_16:%.*]] = "tf.Const"() {value = dense<10> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
// CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
|
||||
// CHECK: [[VAL_19:%.*]] = constant unit
|
||||
// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<?x10xf32>, tensor<?x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
|
||||
// CHECK: [[VAL_21:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_22:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_23:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_24:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<f32>
|
||||
// CHECK: return [[VAL_21]], [[VAL_20]], [[VAL_22]], [[VAL_23]], [[VAL_24]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
func @inference_cannot_fuse(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) {
|
||||
%0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = f32, value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
%1:5 = "tf.PartitionedCall"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], _output_shapes = ["tfshape$dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$"], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\008\01", device = "", executor_type = "", f = @inference_standard_lstm_time_major_cannot_fuse} : (tensor<?x8x8xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<8x40xf32>, tensor<10x40xf32>, tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>)
|
||||
%2 = "tf.Add"(%0, %1#2) : (tensor<f32>, tensor<?x10xf32>) -> tensor<?x10xf32>
|
||||
return
|
||||
}
|
||||
|
||||
func @inference_standard_lstm_time_major_cannot_fuse(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
|
||||
%1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
|
||||
%2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
|
||||
%3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
|
||||
%4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
|
||||
%5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
|
||||
%6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: func @inference_standard_lstm_time_major_cannot_fuse([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
|
||||
// CHECK: [[VAL_6:%.*]] = "tf.BatchMatMulV2"([[VAL_0]], [[VAL_3]]) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
|
||||
// CHECK: [[VAL_7:%.*]] = "tf.Add"([[VAL_6]], [[VAL_5]]) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
|
||||
// CHECK: [[VAL_8:%.*]] = "tf.BatchMatMulV2"([[VAL_7]], [[VAL_4]]) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
|
||||
// CHECK: [[VAL_9:%.*]] = "tf.Add"([[VAL_8]], [[VAL_1]]) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
|
||||
// CHECK: [[VAL_10:%.*]] = "tf.Add"([[VAL_8]], [[VAL_2]]) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
|
||||
// CHECK: [[VAL_11:%.*]] = "tf.Add"([[VAL_1]], [[VAL_2]]) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
|
||||
// CHECK: [[VAL_12:%.*]] = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: return [[VAL_11]], [[VAL_10]], [[VAL_11]], [[VAL_11]], [[VAL_12]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
|
||||
// CHECK: }
|
||||
}
|
||||
|
@@ -347,6 +347,16 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {

    Value element_shape = operands[0];
    Type shape_dtype = getElementTypeOrSelf(element_shape.getType());
    // If the `element_shape` is a scalar, we know that it's dynamic shape
    // and returns an error.
    if (auto shaped_type = element_shape.getType().dyn_cast<ShapedType>()) {
      if (shaped_type.getRank() == 0) {
        op.emitError(
            "requires element_shape to be 1D tensor during TF Lite "
            "transformation pass");
        return ConversionPattern::matchFailure();
      }
    }

    DenseIntElementsAttr dense_elem_attr;
    if (matchPattern(element_shape, m_Constant(&dense_elem_attr))) {
@@ -61,7 +61,7 @@ std::unique_ptr<OpPassBase<ModuleOp>> CreateTrimFunctionsPass(

// Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions
// pass.
std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareCompositeFunctionsPass();
std::unique_ptr<OpPassBase<ModuleOp>> CreatePrepareCompositeFunctionsPass();

// Creates an instance of the TensorFlow Lite dialect ExtractOphint pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateExtractOphintPass();
@@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Analysis/CallInterfaces.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project

@@ -35,10 +36,12 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

// NOLINTNEXTLINE
@@ -91,14 +94,14 @@ class ConvertEmbeddedLookupFunc {
// body with the corresponding fused TFLite op. The replacement need not always
// be a fused op, though that is the primary use case.
class PrepareCompositeFunctionsPass
    : public FunctionPass<PrepareCompositeFunctionsPass> {
    : public ModulePass<PrepareCompositeFunctionsPass> {
 public:
  explicit PrepareCompositeFunctionsPass() {}

 private:
  void ConvertTFImplements(FuncOp func, StringAttr attr);
  void ConvertTFAPIImplements(FuncOp func, StringAttr attr);
  void runOnFunction() override;
  void ConvertTFAPIImplements(FuncOp func, StringAttr attr, ModuleOp module);
  void runOnModule() override;
};

void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
@@ -131,14 +134,54 @@ void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
  }
}

LogicalResult CheckOutputConsumer(
    Operation* call_op, int expected_num_outputs,
    llvm::DenseSet<int> expected_consumer_indices) {
  if (call_op->getNumResults() != expected_num_outputs) return failure();

  for (int i = 0; i < expected_num_outputs; ++i) {
    auto it = expected_consumer_indices.find(i);
    if (it != expected_consumer_indices.end()) {
      // Expected consumer.
      if (call_op->getResult(i).use_empty()) return failure();
    } else {
      // Unexpected consumer.
      if (!call_op->getResult(i).use_empty()) return failure();
    }
  }
  return success();
}

LogicalResult CheckFusableKerasLstm(FuncOp lstm_func, ModuleOp module) {
  bool check_failed = false;
  for (auto func : module.getOps<FuncOp>()) {
    func.walk([&](Operation* op) {
      auto call_op = dyn_cast_or_null<CallOpInterface>(op);
      if (call_op && op->getAttrOfType<SymbolRefAttr>("f").getRootReference() ==
                         lstm_func.getName()) {
        // Keras LSTM have 5 outputs.
        // We should make sure only the second output is consumed.
        if (failed(CheckOutputConsumer(call_op, 5, {1}))) check_failed = true;
      }
    });
  }

  if (check_failed) return failure();
  return success();
}

void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
                                                           StringAttr attr) {
                                                           StringAttr attr,
                                                           ModuleOp module) {
  // Keras lstm tf.api_implements usually has attribute like "lstm_abcde91...".
  // TODO(b/147436982): we need to make sure that only the
  // outputs(full sequence) is used, not the last_output, not the new_states.
  // We will discard everything except the outputs.
  // And the outputs is in the shape of [batch, time, units].
  if (attr.getValue().startswith("lstm_")) {
    // Check if the keras lstm can be fused, if not, we just don't do anything.
    if (failed(CheckFusableKerasLstm(func, module))) return;

    func.eraseBody();
    func.addEntryBlock();
@@ -148,26 +191,29 @@ void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func,
  }
}

void PrepareCompositeFunctionsPass::runOnFunction() {
  auto func = getFunction();
  // We have two kinds of implements:
  // 1) tf._implements.
  // 2) tf.api_implements.
  // We need to handle them separately.
  auto tf_implements_attr = func.getAttrOfType<StringAttr>(kTFImplements);
  if (tf_implements_attr) {
    ConvertTFImplements(func, tf_implements_attr);
  } else {
void PrepareCompositeFunctionsPass::runOnModule() {
  auto module = getModule();
  for (auto func : module.getOps<FuncOp>()) {
    // We have two kinds of implements:
    // 1) tf._implements.
    // 2) tf.api_implements.
    // We need to handle them separately.
    auto tf_implements_attr = func.getAttrOfType<StringAttr>(kTFImplements);
    if (tf_implements_attr) {
      ConvertTFImplements(func, tf_implements_attr);
    }

    auto tf_api_implements_attr =
        func.getAttrOfType<StringAttr>(kTFAPIImplements);
    if (!tf_api_implements_attr) return;
    // TODO(b/147536816): Keras lstm should set up the correct attributes.
    ConvertTFAPIImplements(func, tf_api_implements_attr);
    if (tf_api_implements_attr) {
      // TODO(b/147536816): Keras lstm should set up the correct attributes.
      ConvertTFAPIImplements(func, tf_api_implements_attr, module);
    }
  }
}
} // namespace

std::unique_ptr<OpPassBase<FuncOp>> CreatePrepareCompositeFunctionsPass() {
std::unique_ptr<OpPassBase<ModuleOp>> CreatePrepareCompositeFunctionsPass() {
  return std::make_unique<PrepareCompositeFunctionsPass>();
}
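Since `PrepareCompositeFunctionsPass` is now a module pass, its factory returns `OpPassBase<ModuleOp>` and the pass has to be scheduled at module scope. A rough sketch of the wiring, assuming the factory lives in the `mlir::TFL` namespace like the other TFLite passes and an MLIR `PassManager` of the same vintage as this diff:

```cpp
#include "mlir/IR/Module.h"          // ModuleOp (path assumed for this MLIR vintage)
#include "mlir/Pass/PassManager.h"   // mlir::PassManager
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

// Illustrative only: runs the now module-scoped pass over a module that holds
// the Keras LSTM functions.
mlir::LogicalResult RunPrepareCompositeFunctions(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
  return pm.run(module);
}
```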
@@ -53,10 +53,10 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"

#define DEBUG_TYPE "tf-tfl-legalization"
@@ -652,8 +652,8 @@ void PrepareTFPass::runOnFunction() {
  patterns.clear();
  TFL::populateWithGenerated(ctx, &patterns);
  if (unfold_batch_matmul_) {
    patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
                    ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(ctx);
    patterns.insert<TF::ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
                    TF::ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(ctx);
  }
  patterns.insert<ConvertTFConv2D, ConvertTFDepthwiseConv2dNative,
                  ConvertTFStridedSlice>(ctx);
@@ -692,7 +692,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {

  Value none = builder->create<mlir::ConstantOp>(
      func_op.getLoc(), builder->getNoneType(), builder->getUnitAttr());
  auto lstm = builder->create<mlir::TFL::LSTMOp>(
  auto lstm = builder->create<mlir::TFL::UnidirectionalSequenceLSTMOp>(
      func_op.getLoc(), result_type, /*input=*/input,
      /*input_to_input_weights=*/weights_array->getResult(0),
      /*input_to_forget_weights=*/weights_array->getResult(1),

@@ -718,7 +718,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::FuncOp func_op, OpBuilder* builder) {
      /*cell_layer_norm_coefficients=*/none,
      /*output_layer_norm_coefficients=*/none, builder->getStringAttr("TANH"),
      builder->getF32FloatAttr(10.0), builder->getF32FloatAttr(0.0),
      builder->getStringAttr("FULL"));
      builder->getBoolAttr(true));

  auto final_output = lstm.getResult();
  if (!time_majored) {
@ -198,6 +198,7 @@ cc_library(
|
||||
"ir/tf_verifiers.h",
|
||||
"transforms/bridge.h",
|
||||
"transforms/passes.h",
|
||||
"transforms/unroll_batch_matmul.h",
|
||||
"@llvm-project//mlir:include/mlir/Analysis/CallInterfaces.h",
|
||||
"@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h",
|
||||
],
|
||||
@ -297,6 +298,7 @@ cc_library(
|
||||
"transforms/shape_inference.cc",
|
||||
"transforms/shape_inference_pass.cc",
|
||||
"transforms/sink_constant.cc",
|
||||
"transforms/stack_ops_decomposition.cc",
|
||||
"transforms/test_side_effect_analysis.cc",
|
||||
"transforms/tf_device_assignment.cc",
|
||||
"transforms/tpu_cluster_formation.cc",
|
||||
@ -304,7 +306,9 @@ cc_library(
|
||||
"transforms/tpu_dynamic_padding_mapper.cc",
|
||||
"transforms/tpu_merge_variables_with_execute.cc",
|
||||
"transforms/tpu_rewrite_pass.cc",
|
||||
"transforms/tpu_sharding_identification_pass.cc",
|
||||
"transforms/tpu_variable_runtime_reformatting.cc",
|
||||
"transforms/unroll_batch_matmul.cc",
|
||||
"translate/breakup-islands.cc",
|
||||
"translate/control_to_executor_dialect.cc",
|
||||
"translate/executor_to_control_dialect.cc",
|
||||
@ -314,6 +318,7 @@ cc_library(
|
||||
"transforms/bridge.h",
|
||||
"transforms/passes.h",
|
||||
"transforms/shape_inference.h",
|
||||
"transforms/unroll_batch_matmul.h",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
@ -335,6 +340,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/lite:validators",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla:xla_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:sharding_builder",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
@ -342,12 +348,14 @@ cc_library(
|
||||
"//tensorflow/core/platform:random",
|
||||
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
|
||||
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Parser",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
@ -396,6 +404,7 @@ cc_library(
|
||||
deps = [
|
||||
":convert_tensor",
|
||||
":convert_type",
|
||||
":error_util",
|
||||
":export_tf_dialect_op",
|
||||
":export_utils",
|
||||
":mangling_util",
|
||||
@ -972,6 +981,18 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mlir_local_var_op",
|
||||
srcs = ["ops/mlir_local_var_op.cc"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_py(
|
||||
name = "gen_mlir_passthrough_op_py",
|
||||
out = "gen_mlir_passthrough_op.py",
|
||||
|
@@ -99,6 +99,24 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context)
  addInterfaces<TFInlinerInterface>();
}

//===----------------------------------------------------------------------===//
// tf_device.launch
//===----------------------------------------------------------------------===//

// Checks if a tf_device.launch wraps a single operation and the single
// operation results are perfectly forwarded to the launch return.
bool LaunchOp::WrapsSingleOp() {
  auto body = GetBody().without_terminator();
  if (!has_single_element(body)) return false;

  Operation& wrapped_op = *body.begin();
  Operation* terminator = GetBody().getTerminator();
  return wrapped_op.getNumResults() == terminator->getNumOperands() &&
         std::equal(wrapped_op.getResults().begin(),
                    wrapped_op.getResults().end(),
                    terminator->getOperands().begin());
}

//===----------------------------------------------------------------------===//
// tf_device.return
//===----------------------------------------------------------------------===//
@@ -66,6 +66,7 @@ def TfDevice_LaunchOp : TfDevice_Op<"launch",
  let extraClassDeclaration = [{
    Block &GetBody() { return getOperation()->getRegion(0).front(); }
    StringRef getDevice() { return device(); }
    bool WrapsSingleOp();
  }];

  let builders = [
@@ -129,26 +130,27 @@ def TfDevice_ReplicateOp :
  let summary = "Wraps an N-way replicated computation.";

  let description = [{
The region held by this operation represents a computation that is
replicated across multiple devices. The number of replications is based
on the `n` attribute. Explicit devices can be populated in the `devices`
attribute, and it must be a map of device alias and a list of explicit
device names or aliased device names from outer scope. Device name list
specifies devices on which replicated ops inside tf_device.replicate will
be executed. If tf_device.parallel_execute is inside replicate op region,
region inside replicate may represent computations across larger set of
devices. In that case, device alias can be used to specify device assignment
and replication of each concurrent execution (i.e. region) defined by
parallel_execute op. The size of the device name list must match `n`.
Within a replica, the execution semantics follow standard sequential behavior.
Ops in the replicate with no device assigned will have its device set to the
associated replicate device from `devices`. Operands are replicated inputs:
each group of `n` inputs corresponds to an input for a single individual
replica and is mapped to a single region argument. Inside one group the
operands are matching in order the `devices` attribute. Each replicated
input must have compatible shapes and types. Operands not replicated can
be implicitly captured by ops in the region. Results are replicated each
from the regions terminator.
The region held by this operation represents a computation that is replicated
across multiple devices. The number of replications is based on the `n`
attribute. Explicit devices can be populated in the `devices` attribute, and it
must be a mapping of device alias to list of explicit or aliased device names
from the outer scope. The device name map specifies devices on which replicated
ops inside tf_device.replicate will be executed. A tf_device.parallel_execute
inside the tf_device.replicate op region may be used to represent computations
across a larger set of devices. In that case, the device alias can be used to
specify device assignment and replication of each concurrent execution
(i.e. region) defined by tf_device.parallel_execute op. The size of each value
list in the device name map must match `n`. Within a replica, the execution
semantics follow standard sequential behavior. Ops in the tf_device.replicate
wrapped with a tf_device.launch will have its device set to the associated
replicated device from `devices` if the tf_device.launch refers to an aliased
device name. Otherwise the device already set in tf_device.launch is used
instead. Operands are replicated inputs: each group of `n` inputs corresponds to
an input for a single individual replica and is mapped to a single region
argument. Inside one group the operands are matching in order the `devices`
attribute. Each replicated input must have compatible shapes and types. Operands
not replicated can be implicitly captured by ops in the region. Results are
replicated each from the regions terminator.

For example:
```
@ -156,22 +158,45 @@ For example:
|
||||
%1 = "tf.opB"() : () -> tensor<i32>
|
||||
%2 = "tf.opC"() : () -> tensor<f32>
|
||||
%3 = "tf.opD"() : () -> tensor<f32>
|
||||
%4 = "tf.opE"() : () -> tensor<i1>
|
||||
%output:4 = tf_device.replicate([%0, %1] as %input_0:tensor<i32>,
|
||||
[%2, %3] as %input_1:tensor<f32>)
|
||||
%4 = "tf.opE"() : () -> tensor<!tf.resource>
|
||||
%5 = "tf.opF"() : () -> tensor<!tf.resource>
|
||||
%6 = "tf.opG"() : () -> tensor<!tf.string>
|
||||
%7 = "tf.opH"() : () -> tensor<!tf.string>
|
||||
%8 = "tf.opI"() : () -> tensor<i1>
|
||||
%output:8 = tf_device.replicate([%0, %1] as %input_0:tensor<i32>,
|
||||
[%2, %3] as %input_1:tensor<f32>,
|
||||
[%4, %5] as %input_2:tensor<!tf.resource>
|
||||
[%6, %7] as %input_3:tensor<!tf.string>)
|
||||
{n = 2 : i32,
|
||||
devices = { DEVICE_ALIAS = ["/DEVICE:0", "/DEVICE:1"]}} {
|
||||
// Inside the region, %input_0 corresponds to %0 for "/DEVICE:0" and %1 for
|
||||
// "/DEVICE:1", and %input_1 corresponds to %2 for "/DEVICE:0" and %3 for
|
||||
// "/DEVICE:1".
|
||||
%f = "tf.opF"(%input_0, %4) : (tensor<i32>, tensor<i1>) -> tensor<i32>
|
||||
%g = "tf.opG"(%input_1, %4) : (tensor<f32>, tensor<i1>) -> tensor<f32>
|
||||
tf_device.return %f, %g : tensor<i32>, tensor<f32>
|
||||
devices = {DEVICE_ALIAS_0 = ["/DEVICE:0", "/DEVICE:1"],
|
||||
DEVICE_ALIAS_1 = ["/DEVICE:2", "/DEVICE:3"]}} {
|
||||
// Inside the region, %0, %2, %4, and %6 corresponds to
|
||||
// "/DEVICE:0"/"/DEVICE:2" and %1, %3, %5, and %7 corresponds to
|
||||
// "/DEVICE:1"/"/DEVICE:3", depending on which device alias is used.
|
||||
%j = "tf_device.launch"() ( {
|
||||
%9 = "tf.opJ"(%input_0, %6) : (tensor<i32>, tensor<i1>) -> tensor<i32>
|
||||
tf_device.return %9 : tensor<i32>
|
||||
}) {device = "DEVICE_ALIAS_0"} : () -> tensor<i32>
|
||||
%k = "tf_device.launch"() ( {
|
||||
%10 = "tf.opK"(%input_1, %6) : (tensor<f32>, tensor<i1>) -> tensor<f32>
|
||||
tf_device.return %10 : tensor<f32>
|
||||
}) {device = "DEVICE_ALIAS_1"} : () -> tensor<f32>
|
||||
%l = "tf_device.launch"() ( {
|
||||
%11 = "tf.opL"(%input_2, %6) : (tensor<!tf.resource>, tensor<i1>)
|
||||
-> tensor<!tf.resource>
|
||||
tf_device.return %11 : tensor<!tf.resource>
|
||||
}) {device = "/DEVICE:4"} : () -> tensor<f32>
|
||||
%m = "tf.opM"(%input_3, %6) : (tensor<!tf.string>, tensor<i1>)
|
||||
-> tensor<!tf.string>
|
||||
tf_device.return %j, %k, %l, %m :
|
||||
tensor<i32>, tensor<f32>, tensor<!tf.resource>, tensor<!tf.string>
|
||||
}
|
||||
// %output#0 corresponds to %f returned from "/DEVICE:0"
|
||||
// %output#1 corresponds to %f returned from "/DEVICE:1"
|
||||
// %output#2 corresponds to %g returned from "/DEVICE:0"
|
||||
// %output#3 corresponds to %g returned from "/DEVICE:1"
|
||||
// %output#0 corresponds to %j returned from "/DEVICE:0"
|
||||
// %output#1 corresponds to %j returned from "/DEVICE:1"
|
||||
// %output#2 corresponds to %k returned from "/DEVICE:2"
|
||||
// %output#3 corresponds to %k returned from "/DEVICE:3"
|
||||
// %output#4, %output#5 corresponds to %l and will be returned from "/DEVICE:4"
|
||||
// %output#6, %output#7 corresponds to %m and will have no device set
|
||||
```
|
||||
}];
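Supplementary sketch (not part of the .td source above): a stripped-down replicate showing one replicated input group next to an implicitly captured, non-replicated value; op names, values, and device strings are placeholders.

```
%a = "tf.opA"() : () -> tensor<f32>
%b = "tf.opB"() : () -> tensor<f32>
%cst = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
%out:2 = tf_device.replicate([%a, %b] as %input: tensor<f32>)
             {n = 2 : i32, devices = {ALIAS = ["/DEVICE:0", "/DEVICE:1"]}} {
  // %cst is not a replicated input: it is implicitly captured and shared by
  // both replicas, while %input maps to %a on "/DEVICE:0" and %b on "/DEVICE:1".
  %sum = "tf.AddV2"(%input, %cst) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  tf_device.return %sum : tensor<f32>
}
// %out#0 is %sum from "/DEVICE:0"; %out#1 is %sum from "/DEVICE:1".
```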

@@ -194,13 +219,11 @@ For example:
  let builders = [
    OpBuilder<"Builder* builder, OperationState& state, int n, "
              "const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& devices, "
              "llvm::ArrayRef<std::pair<llvm::ArrayRef<Value>, Type>>"
              " replicated_inputs, "
              "llvm::ArrayRef<std::pair<llvm::ArrayRef<Value>, Type>> replicated_inputs, "
              "llvm::ArrayRef<Type> replica_output_types">,
    OpBuilder<"Builder* builder, OperationState& state, int n, "
              "const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& devices, "
              "llvm::ArrayRef<std::pair<Operation::operand_range, Type>>"
              " replicated_inputs, "
              "llvm::ArrayRef<std::pair<Operation::operand_range, Type>> replicated_inputs, "
              "Operation::result_type_range replica_output_types">
  ];
@@ -3926,6 +3926,21 @@ pad(t, paddings) ==> [[2, 1, 1, 2, 3, 3, 2]
  TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>;
}

def TF_MlirLocalVarOp : TF_Op<"MlirLocalVarOp", []> {
  let summary = "Creates a handle to an in-scope variable.";

  let description = [{
Used by internal passes for temporary representation of local state, which will
be eventually removed.
  }];

  let arguments = (ins);

  let results = (outs
    TF_ResourceTensor:$resource
  );
}
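For orientation, a minimal sketch of how a decomposition pass can use this op as a scratch resource variable; the element shape is an assumption chosen for illustration, and the surrounding ops mirror the stack-decomposition tests later in this change.

```
// Create a temporary local resource, initialize it, then read it back.
%var = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<10xf32>>>
%init = "tf.Const"() {value = dense<0.0> : tensor<10xf32>} : () -> tensor<10xf32>
"tf.AssignVariableOp"(%var, %init) : (tensor<!tf.resource<tensor<10xf32>>>, tensor<10xf32>) -> ()
%val = "tf.ReadVariableOp"(%var) : (tensor<!tf.resource<tensor<10xf32>>>) -> tensor<10xf32>
```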

def TF_MlirPassthroughOp : TF_Op<"MlirPassthroughOp", [NoSideEffect]> {
  let summary = [{
Wraps an arbitrary MLIR computation expressed as a module with a main() function.
@@ -6418,6 +6433,74 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_StackCloseV2Op : TF_Op<"StackCloseV2", []> {
  let summary = "Delete the stack from its resource container.";

  let description = [{
  }];

  let arguments = (ins
    TF_ResourceTensor:$handle
  );

  let results = (outs);
}

def TF_StackPopV2Op : TF_Op<"StackPopV2", []> {
  let summary = "Pop the element at the top of the stack.";

  let description = [{
  }];

  let arguments = (ins
    TF_ResourceTensor:$handle
  );

  let results = (outs
    TF_Tensor:$elem
  );

  TF_DerivedResultTypeAttr elem_type = TF_DerivedResultTypeAttr<0>;
}

def TF_StackPushV2Op : TF_Op<"StackPushV2", []> {
  let summary = "Push an element onto the stack.";

  let description = [{
  }];

  let arguments = (ins
    TF_ResourceTensor:$handle,
    TF_Tensor:$elem,

    DefaultValuedAttr<BoolAttr, "false">:$swap_memory
  );

  let results = (outs
    TF_Tensor:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
}

def TF_StackV2Op : TF_Op<"StackV2", []> {
  let summary = "A stack that produces elements in first-in last-out order.";

  let description = [{
  }];

  let arguments = (ins
    I32Tensor:$max_size,

    TypeAttr:$elem_type,
    StrAttr:$stack_name
  );

  let results = (outs
    TF_ResourceTensor:$handle
  );
}
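Taken together, these four ops model a mutable stack through a single resource handle. A condensed sketch of the usage pattern (the same shape the stack-decomposition tests later in this change exercise; the constant values are illustrative):

```
%max = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
%stack = "tf.StackV2"(%max) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
%elem = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
// Push returns the pushed element; pop withdraws the element at the top.
%pushed = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
%popped = "tf.StackPopV2"(%stack) : (tensor<!tf.resource>) -> tensor<f32>
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
```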

def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Stops gradient computation.";

@@ -6701,6 +6784,24 @@ occurred during compilation.
  );
}

def TF_TPUCompileSucceededAssertOp : TF_Op<"TPUCompileSucceededAssert", []> {
  let summary = [{
Asserts that compilation succeeded. This op produces no output and closes the
  }];

  let description = [{
device during failure to ensure all pending device interactions fail.

'compilation_status' is a serialized CompilationResultProto.
  }];

  let arguments = (ins
    TF_StrTensor:$compilation_status
  );

  let results = (outs);
}

def TF_TPUCopyWithLayoutOp : TF_Op<"TPUCopyWithLayout", [NoSideEffect]> {
  let summary = "Op that copies host tensor to device with specified layout.";

@@ -6919,6 +7020,38 @@ is the corresponding input gradient.
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_TensorListConcatV2Op : TF_Op<"TensorListConcatV2", [NoSideEffect]> {
  let summary = "Concats all tensors in the list along the 0th dimension.";

  let description = [{
Requires that all tensors have the same shape except the first dimension.

input_handle: The input list.
element_shape: The shape of the uninitialized elements in the list. If the first
  dimension is not -1, it is assumed that all list elements have the same
  leading dim.
leading_dims: The list of leading dims of uninitialized list elements. Used if
  the leading dim of input_handle.element_shape or the element_shape input arg
  is not already set.
tensor: The concated result.
lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient.
  }];

  let arguments = (ins
    TF_VariantTensor:$input_handle,
    TF_I32OrI64Tensor:$element_shape,
    I64Tensor:$leading_dims
  );

  let results = (outs
    TF_Tensor:$tensor,
    I64Tensor:$lengths
  );

  TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<1>;
  TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<0>;
}
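A hypothetical invocation, assuming an already-built list `%list` of `tensor<2xf32>` elements (the handle, shapes, and result types below are illustrative assumptions, not taken from this change):

```
// Concatenate every element of %list along dimension 0.
%element_shape = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
%leading_dims = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
%concat:2 = "tf.TensorListConcatV2"(%list, %element_shape, %leading_dims)
    : (tensor<!tf.variant<tensor<2xf32>>>, tensor<1xi32>, tensor<0xi64>) -> (tensor<?x2xf32>, tensor<?xi64>)
// %concat#0 is the concatenated tensor; %concat#1 holds the 0th-dimension
// sizes of the original list elements, used when computing the gradient.
```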

def TF_TensorListFromTensorOp : TF_Op<"TensorListFromTensor", [NoSideEffect]> {
  let summary = [{
Creates a TensorList which, when stacked, has the value of `tensor`.
@@ -6980,6 +7113,33 @@ length: the number of tensors in the list
  );
}

def TF_TensorListPopBackOp : TF_Op<"TensorListPopBack", [NoSideEffect]> {
  let summary = [{
Returns the last element of the input list as well as a list with all but that element.
  }];

  let description = [{
Fails if the list is empty.

input_handle: the input list
tensor: the withdrawn last element of the list
element_dtype: the type of elements in the list
element_shape: the shape of the output tensor
  }];

  let arguments = (ins
    TF_VariantTensor:$input_handle,
    I32Tensor:$element_shape
  );

  let results = (outs
    TF_VariantTensor:$output_handle,
    TF_Tensor:$tensor
  );

  TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<1>;
}

def TF_TensorListPushBackOp : TF_Op<"TensorListPushBack", [NoSideEffect]> {
  let summary = [{
Returns a list which has the passed-in `Tensor` as last element and the other elements of the given list in `input_handle`.
@@ -7823,3 +7983,38 @@ def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF__TPUCompileMlirOp : TF_Op<"_TPUCompileMlir", []> {
  let summary = [{
Compiles a computation for execution on one or more TPU devices.
  }];

  let description = [{
For the internal use of the distributed TPU compiler. Note that currently only
a single TPU device is supported.

'mlir_module' is a serialized MLIR module with a `main` function that contains
target computation.
'dynamic_shapes' contains dynamic shapes of arguments whose shapes were not
known statically at TPUReplication rewrite time.
'metadata' is a serialized TPUCompileMetadataProto describing
the shapes and types of the inputs to the computation, as well as a mapping onto
the TPU pod topology.
'program' output is a string key that is passed to the _TPUExecute op and
used to look up the program in the compilation cache.
  }];

  let arguments = (ins
    Variadic<I64Tensor>:$dynamic_shapes,

    StrAttr:$mlir_module,
    StrAttr:$metadata
  );

  let results = (outs
    TF_StrTensor:$compilation_status,
    TF_StrTensor:$program
  );

  TF_DerivedOperandSizeAttr NumDynamicShapes = TF_DerivedOperandSizeAttr<0>;
}
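The intended consumption pattern, condensed from the TPU dynamic-layout tests later in this change (the metadata payload is elided, and the input value, shapes, and device string are assumptions for illustration):

```
// Compile once; the first result is the status, the second the program key.
%compile:2 = "tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "...", mlir_module = "..."}
    : () -> (tensor<!tf.string>, tensor<!tf.string>)
// Fail fast, closing the device, if compilation did not succeed.
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
// The program key is then passed to the execute op alongside the inputs.
%result = "tf.TPUExecute"(%input, %compile#1) {device = "/device:TPU:0"}
    : (tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
```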

@@ -1,4 +1,4 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
@@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef CUDA_CUDA_CONFIG_H_
#define CUDA_CUDA_CONFIG_H_

#define TF_CUDA_CAPABILITIES CudaVersion("3.0"), CudaVersion("6.0")

#define TF_CUDA_VERSION "10.1"
#define TF_CUDA_LIB_VERSION "10"
#define TF_CUDNN_VERSION "7"

#define TF_CUDA_TOOLKIT_PATH "/usr/local/cuda-10.1"

#endif  // CUDA_CUDA_CONFIG_H_

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

REGISTER_OP("MlirLocalVarOp")
    .Output("resource: resource")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"(Creates a handle to an in-scope variable.

Used by internal passes for temporary representation of local state, which will
be eventually removed.)");

}  // namespace tensorflow
@@ -3,7 +3,10 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
licenses(["notice"])

glob_lit_tests(
    data = [":test_utilities"],
    data = [
        ":debug_info_files",
        ":test_utilities",
    ],
    driver = "@llvm-project//mlir:run_lit.sh",
    test_file_exts = ["pbtxt"],
)
@@ -18,3 +21,13 @@ filegroup(
        "@llvm-project//llvm:not",
    ],
)

# Bundle together all the debug info files that are used by the tests.
filegroup(
    name = "debug_info_files",
    srcs = glob(
        [
            "**/*.debug",
        ],
    ),
)
@@ -0,0 +1,62 @@
# RUN: not tf-mlir-translate -graphdef-to-mlir -tf-input-arrays=x,y -tf-input-data-types=DT_INT32,DT_INT32 -tf-input-shapes=2:3 -tf-output-arrays=x_y_sum %s --tf-debug-info=%s.debug -o - 2>&1 | FileCheck %s

# Checks that source debug information is used in the output error message.
# CHECK: Graph import failed: Invalid argument: Dimensions must be equal
# CHECK: math_ops.add(x, y, name='x_y_sum')
# CHECK: build_graph(out_dir)
node: {
  name: "x"
  op: "Placeholder"
  attr: {
    key: "shape"
    value: {
      shape: {
        dim: {
          size: -1
        }
      }
    }
  }
  attr: {
    key: "dtype"
    value: {
      type: DT_INT32
    }
  }
}
node: {
  name: "y"
  op: "Placeholder"
  attr: {
    key: "shape"
    value: {
      shape: {
        dim: {
          size: -1
        }
      }
    }
  }
  attr: {
    key: "dtype"
    value: {
      type: DT_INT32
    }
  }
}
node: {
  name: "x_y_sum"
  op: "Add"
  input: "x"
  input: "y"
  attr: {
    key: "T"
    value: {
      type: DT_INT32
    }
  }
}
versions: {
  producer: 321
}
@@ -0,0 +1,28 @@
files : [ "org_tensorflow/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/error-message-with-source-info.pbtxt.fake_py.debug"]
traces: {
  key : "x@"
  value: {
    file_line_cols: {
      line : 1
    }
  }
}
traces: {
  key : "x_y_sum@"
  value: {
    file_line_cols: {
      line : 3
    }
    file_line_cols: {
      line : 4
    }
  }
}
traces: {
  key : "y@"
  value: {
    file_line_cols: {
      line : 2
    }
  }
}
@@ -0,0 +1,4 @@
x = value
y = value
math_ops.add(x, y, name='x_y_sum')
build_graph(out_dir)
@ -25,13 +25,16 @@ func @controls_per_replica() {
|
||||
// CHECK: %[[ISLAND_2:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_0]], %[[ISLAND_1]])
|
||||
|
||||
|
||||
// Tests devices are not set if no devices were defined in replicate.
|
||||
// Tests devices are not remapped if no devices were defined in replicate.
|
||||
// CHECK-LABEL: func @no_devices
|
||||
func @no_devices() {
|
||||
tf_executor.graph {
|
||||
%0 = tf_executor.island {
|
||||
tf_device.replicate {n = 2 : i32} {
|
||||
"tf.opA"() : () -> ()
|
||||
"tf_device.launch"() ( {
|
||||
"tf.opA"() : () -> ()
|
||||
tf_device.return
|
||||
}) {device = "CORE_0"} : () -> ()
|
||||
tf_device.return
|
||||
}
|
||||
tf_executor.yield
|
||||
@ -41,19 +44,22 @@ func @no_devices() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: "tf.opA"
|
||||
// CHECK-NOT: device
|
||||
// CHECK: "tf.opA"
|
||||
// CHECK-NOT: device
|
||||
// CHECK: "tf.opA"
|
||||
// CHECK: device = "CORE_0"
|
||||
// CHECK: "tf.opA"
|
||||
// CHECK: device = "CORE_0"
|
||||
|
||||
|
||||
// Tests devices are not set if op already has a device assigned.
|
||||
// Tests devices are not remapped if device is not in replicate devices.
|
||||
// CHECK-LABEL: func @no_override_device
|
||||
func @no_override_device() {
|
||||
tf_executor.graph {
|
||||
%0 = tf_executor.island {
|
||||
tf_device.replicate {n = 2 : i32, devices = {CORE_0 = ["/CPU:0", "/GPU:1"]}} {
|
||||
"tf.opA"() {device = "/TPU:2"} : () -> ()
|
||||
"tf_device.launch"() ( {
|
||||
"tf.opA"() : () -> ()
|
||||
tf_device.return
|
||||
}) {device = "/TPU:2"} : () -> ()
|
||||
tf_device.return
|
||||
}
|
||||
tf_executor.yield
|
||||
@ -63,20 +69,22 @@ func @no_override_device() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: "tf.opA"
|
||||
// CHECK-SAME: device = "/TPU:2"
|
||||
// CHECK: "tf.opA"
|
||||
// CHECK-SAME: device = "/TPU:2"
|
||||
// CHECK: "tf.opA"
|
||||
// CHECK: device = "/TPU:2"
|
||||
// CHECK: "tf.opA"
|
||||
// CHECK: device = "/TPU:2"
|
||||
|
||||
|
||||
// Tests devices are not set if op is not of the TF dialect.
|
||||
// CHECK-LABEL: func @no_device_non_tf_ops
|
||||
func @no_device_non_tf_ops() {
|
||||
// Tests devices are remapped if device is in replicate devices.
|
||||
// CHECK-LABEL: func @remap_device
|
||||
func @remap_device() {
|
||||
tf_executor.graph {
|
||||
%0 = tf_executor.island {
|
||||
tf_device.replicate {n = 2 : i32, devices = {CORE_0 = ["/CPU:0", "/GPU:1"]}} {
|
||||
"test.opA"() {device = "/TPU:2"} : () -> ()
|
||||
"test.opB"() : () -> ()
|
||||
"tf_device.launch"() ( {
|
||||
"tf.opA"() : () -> ()
|
||||
tf_device.return
|
||||
}) {device = "CORE_0"} : () -> ()
|
||||
tf_device.return
|
||||
}
|
||||
tf_executor.yield
|
||||
@ -86,14 +94,10 @@ func @no_device_non_tf_ops() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: "test.opA"
|
||||
// CHECK-SAME: device = "/TPU:2"
|
||||
// CHECK: "test.opB"
|
||||
// CHECK-NOT: device
|
||||
// CHECK: "test.opA"
|
||||
// CHECK-SAME: device = "/TPU:2"
|
||||
// CHECK: "test.opB"
|
||||
// CHECK-NOT: device
|
||||
// CHECK: "tf.opA"
|
||||
// CHECK: device = "/CPU:0"
|
||||
// CHECK: "tf.opA"
|
||||
// CHECK: device = "/GPU:1"
|
||||
|
||||
|
||||
// Tests unused per replica island are added as a control dependency to the
|
||||
@ -104,7 +108,7 @@ func @unused_replica_control(%arg0: tensor<i1>, %arg1: tensor<i1>) {
|
||||
%0 = tf_executor.graph {
|
||||
%1 = tf_executor.ControlTrigger {}
|
||||
%2:2 = tf_executor.island(%1) {
|
||||
%3:4 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>) {n = 2 : i32, devices = {CORE_0 = ["/CPU:0", "/GPU:1"]}} {
|
||||
%3:4 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>) {n = 2 : i32} {
|
||||
%4 = "tf.opA"(%ri) : (tensor<i1>) -> tensor<i1>
|
||||
%5 = "tf.opB"(%4) : (tensor<i1>) -> tensor<i1>
|
||||
tf_device.return %4, %5 : tensor<i1>, tensor<i1>
|
||||
@ -119,15 +123,11 @@ func @unused_replica_control(%arg0: tensor<i1>, %arg1: tensor<i1>) {
|
||||
// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger
|
||||
// CHECK: %[[ISLAND_0:[a-z_0-9]*]]:2, %{{.*}} = tf_executor.island(%[[CT]])
|
||||
// CHECK: %[[OP_A_0:[0-9]*]] = "tf.opA"(%[[ARG_0]])
|
||||
// CHECK-SAME: device = "/CPU:0"
|
||||
// CHECK: %[[OP_B_0:[0-9]*]] = "tf.opB"(%[[OP_A_0]])
|
||||
// CHECK-SAME: device = "/CPU:0"
|
||||
// CHECK: tf_executor.yield %[[OP_A_0]], %[[OP_B_0]]
|
||||
// CHECK: %[[ISLAND_1:[a-z_0-9]*]]:2, %[[ISLAND_1_control:[a-z_0-9]*]] = tf_executor.island(%[[CT]])
|
||||
// CHECK: %[[OP_A_1:[0-9]*]] = "tf.opA"(%[[ARG_1]])
|
||||
// CHECK-SAME: device = "/GPU:1"
|
||||
// CHECK: %[[OP_B_1:[0-9]*]] = "tf.opB"(%[[OP_A_1]])
|
||||
// CHECK-SAME: device = "/GPU:1"
|
||||
// CHECK: tf_executor.yield %[[OP_A_1]], %[[OP_B_1]]
|
||||
// CHECK: %[[ISLAND_2:.*]], %[[ISLAND_2_control:.*]] = tf_executor.island(%[[ISLAND_1_control]])
|
||||
// CHECK: tf_executor.yield %[[ISLAND_0]]#0
|
||||
|
@ -343,7 +343,7 @@ func @launch_with_loop() -> () {
|
||||
}
|
||||
func @while_body(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>) {
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
// expected-error @+1 {{Resource used in while loop is only supported when the resource input and output alias each other in the loop body}}
|
||||
// expected-error @+1 {{resource used in while loop is only supported when the resource input and output alias each other in the loop body}}
|
||||
return %0 : tensor<*x!tf.resource<tensor<f32>>>
|
||||
}
|
||||
func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
|
||||
@ -367,7 +367,7 @@ func @launch_with_loop() -> () {
|
||||
return
|
||||
}
|
||||
func @while_body(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.resource<tensor<f32>>>) {
|
||||
// expected-error @+1 {{Found unsupported operations on resource.}}
|
||||
// expected-error @+1 {{found unsupported operations on resource.}}
|
||||
"tf._UnknownOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> ()
|
||||
return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
|
||||
}
|
||||
@ -399,7 +399,7 @@ func @while_body(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> (tensor<*x!tf.re
|
||||
func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
|
||||
%read = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
%constant = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||
// expected-error @+1 {{Found resource write in loop condition.}}
|
||||
// expected-error @+1 {{found resource write in loop condition.}}
|
||||
"tf.AssignVariableOp"(%arg0, %constant) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
|
||||
return %read : tensor<f32>
|
||||
}
|
||||
@ -524,7 +524,7 @@ func @launch_with_if(%arg0: tensor<i1>) -> tensor<4xf32> {
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
|
||||
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
|
||||
%2 = "tf_device.launch"() ( {
|
||||
// expected-error @+1 {{Unsupported tf.IfOp output: resource does not alias a single input.}}
|
||||
// expected-error @+1 {{unsupported tf.IfOp output: resource does not alias a single input.}}
|
||||
%3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
|
||||
output_shapes = ["tfshape$"], is_stateless = false}
|
||||
: (tensor<i1>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>)
|
||||
@ -650,7 +650,7 @@ func @launch_with_stateful_partitioned_call() -> () {
|
||||
}) {device = "tpu0", launch_attr = "launch_attr"} : () -> ()
|
||||
return
|
||||
}
|
||||
// expected-error @+1 {{Unsupported function call: resource return value does not alias an input.}}
|
||||
// expected-error @+1 {{unsupported function call: resource return value does not alias an input.}}
|
||||
func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %arg2: tensor<f32>) -> tensor<*x!tf.resource<tensor<f32>>> {
|
||||
%0 = "tf._Unknown_"() : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
return %0 : tensor<*x!tf.resource<tensor<f32>>>
|
||||
|
@ -0,0 +1,254 @@
|
||||
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-stack-ops-decomposition | FileCheck %s -dump-input-on-failure
|
||||
|
||||
// Tests simple scalar stack operations without control flow.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
func @main() -> tensor<f32> {
|
||||
// CHECK-NEXT: "tf.Const"() {value = dense<10> : tensor<i32>}
|
||||
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: %[[BUFFER:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<10xf32>>>
|
||||
// CHECK-NEXT: %[[SIZE:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<1xi32>>>
|
||||
// CHECK-NEXT: %[[ZERO:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[ZERO]])
|
||||
// CHECK-NEXT: %[[ZERO_SCALAR:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: %[[CAST_ZERO:.*]] = "tf.Cast"(%[[ZERO_SCALAR]]) : (tensor<i32>) -> tensor<f32>
|
||||
// CHECK-NEXT: %[[CONST10:.*]] = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK-NEXT: %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[CAST_ZERO]], %[[CONST10]]) : (tensor<f32>, tensor<1xi32>) -> tensor<10xf32>
|
||||
// CHECK-NEXT: "tf.AssignVariableOp"(%[[BUFFER]], %[[BROADCAST]])
|
||||
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
|
||||
%id = "tf.Identity"(%stack) : (tensor<!tf.resource>) -> tensor<!tf.resource>
|
||||
// CHECK-NEXT: %[[PUSHVAL:.*]] = "tf._SomeOp"()
|
||||
%elem = "tf._SomeOp"() : () -> tensor<f32>
|
||||
// CHECK-NEXT: %[[READ_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]])
|
||||
// CHECK-NEXT: %[[READ_SIZE:.*]] = "tf.ReadVariableOp"(%[[SIZE]])
|
||||
// CHECK-NEXT: %[[UPDATE_SHAPE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK-NEXT: %[[UPDATE_SLICE:.*]] = "tf.Reshape"(%[[PUSHVAL]], %[[UPDATE_SHAPE]]) : (tensor<f32>, tensor<1xi32>) -> tensor<1xf32>
|
||||
// CHECK-NEXT: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ_VAL]], %[[UPDATE_SLICE]], %[[READ_SIZE]]) : (tensor<10xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<10xf32>
|
||||
// CHECK-NEXT: "tf.AssignVariableOp"(%[[BUFFER]], %[[UPDATE]]) : (tensor<!tf.resource<tensor<10xf32>>>, tensor<10xf32>) -> ()
|
||||
// CHECK-NEXT: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK-NEXT: %[[NEW_SIZE:.*]] = "tf.AddV2"(%[[READ_SIZE]], %[[CONST1]]) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
// CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[NEW_SIZE]]) : (tensor<!tf.resource<tensor<1xi32>>>, tensor<1xi32>) -> ()
|
||||
%push = "tf.StackPushV2"(%id, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
|
||||
%pop = "tf.StackPopV2"(%stack) : (tensor<!tf.resource>) -> tensor<f32>
|
||||
// CHECK-NEXT: %[[READ_VAL1:.*]] = "tf.ReadVariableOp"(%[[BUFFER]])
|
||||
// CHECK-NEXT: %[[READ_SIZE1:.*]] = "tf.ReadVariableOp"(%[[SIZE]])
|
||||
// CHECK-NEXT: %[[CONST1_1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK-NEXT: %[[SUB:.*]] = "tf.Sub"(%[[READ_SIZE1]], %[[CONST1_1]])
|
||||
// CHECK-NEXT: %[[SLICE_SIZE:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK-NEXT: %[[SLICE:.*]] = "tf.Slice"(%[[READ_VAL1]], %[[SUB]], %[[SLICE_SIZE]]) : (tensor<10xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xf32>
|
||||
// CHECK-NEXT: %[[ELEM_SHAPE:.*]] = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
|
||||
// CHECK-NEXT: %[[ELEM:.*]] = "tf.Reshape"(%[[SLICE]], %[[ELEM_SHAPE]]) : (tensor<1xf32>, tensor<0xi32>) -> tensor<f32>
|
||||
// CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[SUB]]) : (tensor<!tf.resource<tensor<1xi32>>>, tensor<1xi32>) -> ()
|
||||
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
|
||||
// CHECK-NEXT: return %[[ELEM]] : tensor<f32>
|
||||
return %pop : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests simple non-scalar stack operations without control flow.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
func @main() -> tensor<2xi32> {
|
||||
// CHECK-NEXT: "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
%size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: %[[BUFFER:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<10x2xi32>>>
|
||||
// CHECK-NEXT: %[[SIZE:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<1xi32>>>
|
||||
// CHECK-NEXT: %[[ZERO_SIZE:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[ZERO_SIZE]]) : (tensor<!tf.resource<tensor<1xi32>>>, tensor<1xi32>) -> ()
|
||||
// CHECK-NEXT: %[[ZERO_CONST:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: %[[STACK_SHAPE:.*]] = "tf.Const"() {value = dense<[10, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK-NEXT: %[[BROADCAST:.*]] = "tf.BroadcastTo"(%[[ZERO_CONST]], %[[STACK_SHAPE]]) : (tensor<i32>, tensor<2xi32>) -> tensor<10x2xi32>
|
||||
// CHECK-NEXT: "tf.AssignVariableOp"(%[[BUFFER]], %[[BROADCAST]]) : (tensor<!tf.resource<tensor<10x2xi32>>>, tensor<10x2xi32>) -> ()
|
||||
%stack = "tf.StackV2"(%size) {elem_type = i32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
|
||||
// CHECK-NEXT: %[[PUSH_VAL:.*]] = "tf._SomeOp"() : () -> tensor<2xi32>
|
||||
%elem = "tf._SomeOp"() : () -> tensor<2xi32>
|
||||
// CHECK-NEXT: %[[STACK_VAL:.*]] = "tf.ReadVariableOp"(%[[BUFFER]]) : (tensor<!tf.resource<tensor<10x2xi32>>>) -> tensor<10x2xi32>
|
||||
// CHECK-NEXT: %[[STACK_SIZE:.*]] = "tf.ReadVariableOp"(%[[SIZE]]) : (tensor<!tf.resource<tensor<1xi32>>>) -> tensor<1xi32>
|
||||
// CHECK-NEXT: %[[UPDATE_SHAPE:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
// CHECK-NEXT: %[[UPDATE_SLICE:.*]] = "tf.Reshape"(%[[PUSH_VAL]], %[[UPDATE_SHAPE]]) : (tensor<2xi32>, tensor<2xi32>) -> tensor<1x2xi32>
|
||||
// CHECK-NEXT: %[[ZERO_INDS:.*]] = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK-NEXT: %[[CONCAT_DIM:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NEXT: %[[CONCAT_OFFETS:.*]] = "tf.ConcatV2"(%[[STACK_SIZE]], %[[ZERO_INDS]], %[[CONCAT_DIM]]) : (tensor<1xi32>, tensor<1xi32>, tensor<i32>) -> tensor<2xi32>
|
||||
// CHECK-NEXT: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"(%[[STACK_VAL]], %[[UPDATE_SLICE]], %[[CONCAT_OFFETS]]) : (tensor<10x2xi32>, tensor<1x2xi32>, tensor<2xi32>) -> tensor<10x2xi32>
|
||||
// CHECK-NEXT: "tf.AssignVariableOp"(%[[BUFFER]], %[[UPDATE]]) : (tensor<!tf.resource<tensor<10x2xi32>>>, tensor<10x2xi32>) -> ()
|
||||
// CHECK-NEXT: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK-NEXT: %[[NEW_SIZE:.*]] = "tf.AddV2"(%[[STACK_SIZE]], %[[CONST1]]) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
// CHECK-NEXT: "tf.AssignVariableOp"(%[[SIZE]], %[[NEW_SIZE]]) : (tensor<!tf.resource<tensor<1xi32>>>, tensor<1xi32>) -> ()
|
||||
%push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<2xi32>) -> tensor<2xi32>
|
||||
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
|
||||
// CHECK-NEXT: return %[[PUSH_VAL]] : tensor<2xi32>
|
||||
return %push : tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests while loop.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
func @main() -> () {
|
||||
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NOT: tf.Stack
|
||||
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
|
||||
%1:2 = "tf.While"(%stack, %max_size) {
|
||||
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
|
||||
: (tensor<!tf.resource>, tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>)
|
||||
// CHECK: "tf.Slice"
|
||||
%pop = "tf.StackPopV2"(%1#0) : (tensor<!tf.resource>) -> tensor<f32>
|
||||
// CHECK-NOT: tf.Stack
|
||||
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
// CHECK: func @while_body(%[[BARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[BARG1:.*]]: tensor<i32>, %[[BARG2:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
|
||||
func @while_body(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>) {
|
||||
// CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG1]], %[[CONST1]])
|
||||
%sub = "tf.Sub"(%arg1, %const1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%elem = "tf._SomeOp"() : () -> tensor<f32>
|
||||
// CHECK-NOT: "tf.StackPushV2"
|
||||
// CHECK: "tf.XlaDynamicUpdateSlice"
|
||||
// CHECK-NOT: "tf.StackPushV2"
|
||||
%push = "tf.StackPushV2"(%arg0, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: return %[[BARG0]], %[[SUB]], %[[BARG2]]
|
||||
return %arg0, %sub : tensor<!tf.resource>, tensor<i32>
|
||||
}
|
||||
// CHECK: func @while_cond(%[[CARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[CARG1:.*]]: tensor<i32>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
|
||||
func @while_cond(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> tensor<i32> {
|
||||
// CHECK-NEXT: return %[[CARG1]]
|
||||
return %arg1 : tensor<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests IfOp.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
func @main(%arg0: tensor<i1>) -> () {
|
||||
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NOT: tf.Stack
|
||||
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
|
||||
%if_op = "tf.If"(%arg0, %stack) {then_branch = @if_then, else_branch = @if_else, is_stateless = false}
|
||||
: (tensor<i1>, tensor<!tf.resource>) -> tensor<!tf.resource>
|
||||
// CHECK: "tf.Slice"
|
||||
%pop = "tf.StackPopV2"(%if_op) : (tensor<!tf.resource>) -> tensor<f32>
|
||||
// CHECK-NOT: tf.Stack
|
||||
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
// CHECK: func @if_then(%[[TARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[TARG1:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
|
||||
func @if_then(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
|
||||
%elem = "tf._SomeOp"() : () -> tensor<f32>
|
||||
// CHECK-NOT: "tf.StackPushV2"
|
||||
// CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
|
||||
// CHECK: "tf.AssignVariableOp"(%[[TARG0:.*]], %[[UPDATE]])
|
||||
// CHECK: "tf.AssignVariableOp"(%[[EARG1:.*]],
|
||||
// CHECK-NOT: "tf.StackPushV2"
|
||||
%push = "tf.StackPushV2"(%arg0, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
|
||||
return %arg0 : tensor<!tf.resource>
|
||||
}
|
||||
// CHECK: func @if_else(%[[EARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[EARG1:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
|
||||
func @if_else(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
|
||||
// CHECK-NOT: "tf.StackPushV2"
|
||||
// CHECK: "tf.Slice"
|
||||
// CHECK: "tf.AssignVariableOp"(%[[EARG1:.*]],
|
||||
// CHECK-NOT: "tf.StackPushV2"
|
||||
%pop = "tf.StackPopV2"(%arg0) : (tensor<!tf.resource>) -> tensor<f32>
|
||||
return %arg0 : tensor<!tf.resource>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests PartitionedCall/StatefulPartitionedCall.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
func @main(%arg0: tensor<i1>) -> () {
|
||||
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK-NOT: tf.Stack
|
||||
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
|
||||
// CHECK: "tf.StatefulPartitionedCall"
|
||||
// CHECK-SAME: f = @callee_stack_decomposed
|
||||
%call = "tf.StatefulPartitionedCall"(%stack, %arg0) {f = @callee, config = "", config_proto = "", executor_type = ""}
|
||||
: (tensor<!tf.resource>, tensor<i1>) -> tensor<!tf.resource>
|
||||
// CHECK: "tf.PartitionedCall"
|
||||
// CHECK-SAME: f = @callee_stack_decomposed
|
||||
%call2 = "tf.PartitionedCall"(%stack, %arg0) {f = @callee, config = "", config_proto = "", executor_type = ""}
|
||||
: (tensor<!tf.resource>, tensor<i1>) -> tensor<!tf.resource>
|
||||
// CHECK: "tf.Slice"
|
||||
%pop = "tf.StackPopV2"(%call) : (tensor<!tf.resource>) -> tensor<f32>
|
||||
// CHECK-NOT: tf.Stack
|
||||
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: func @callee(%[[AARG0:.*]]: tensor<!tf.resource>, %[[AARG1:.*]]: tensor<i1>) -> tensor<!tf.resource>
|
||||
func @callee(%arg0: tensor<!tf.resource>, %arg1: tensor<i1>) -> tensor<!tf.resource> {
|
||||
%elem = "tf._SomeOp"(%arg1) : (tensor<i1>) -> tensor<f32>
|
||||
// CHECK: tf.StackPushV2"
|
||||
%push = "tf.StackPushV2"(%arg0, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
|
||||
return %arg0 : tensor<!tf.resource>
|
||||
}
|
||||
|
||||
// CHECK: func @callee_stack_decomposed(%[[ARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
|
||||
// CHECK-NOT: "tf.StackPushV2"
|
||||
// CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
|
||||
// CHECK: "tf.AssignVariableOp"(%[[TARG0:.*]], %[[UPDATE]])
|
||||
// CHECK: "tf.AssignVariableOp"(%[[EARG1:.*]],
|
||||
// CHECK-NOT: "tf.StackPushV2"
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass reports error on unknown stack size.
|
||||
|
||||
func @main(%arg0: tensor<i32>) -> tensor<2xi32> {
|
||||
// expected-error @+1 {{max size of stack is not a constant.}}
|
||||
%stack = "tf.StackV2"(%arg0) {elem_type = i32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
|
||||
%elem = "tf._SomeOp"() : () -> tensor<2xi32>
|
||||
%push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<2xi32>) -> tensor<2xi32>
|
||||
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
|
||||
return %push : tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass reports error on unknown element shape.
|
||||
|
||||
func @main(%arg0: tensor<i32>) -> () {
|
||||
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
// expected-error @+1 {{cannot infer element shape of stack.}}
|
||||
%stack = "tf.StackV2"(%max_size) {elem_type = i32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
|
||||
%elem = "tf._SomeOp"() : () -> tensor<*xi32>
|
||||
%push = "tf.StackPushV2"(%stack, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<*xi32>) -> tensor<*xi32>
|
||||
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass reports error on ambiguous stack.
|
||||
|
||||
func @main(%arg0: tensor<i1>) -> () {
|
||||
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>
|
||||
%stack2 = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s2"} : (tensor<i32>) -> tensor<!tf.resource>
|
||||
%if_op = "tf.If"(%arg0, %stack, %stack2) {then_branch = @if_then, else_branch = @if_else, is_stateless = false}
|
||||
: (tensor<i1>, tensor<!tf.resource>, tensor<!tf.resource>) -> tensor<!tf.resource>
|
||||
// expected-error @+1 {{unknown stack.}}
|
||||
%pop = "tf.StackPopV2"(%if_op) : (tensor<!tf.resource>) -> tensor<f32>
|
||||
"tf.StackCloseV2"(%stack) : (tensor<!tf.resource>) -> ()
|
||||
// CHECK: return
|
||||
return
|
||||
}
|
||||
func @if_then(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
|
||||
%elem = "tf._SomeOp"() : () -> tensor<f32>
|
||||
%push = "tf.StackPushV2"(%arg0, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
|
||||
return %arg0 : tensor<!tf.resource>
|
||||
}
|
||||
func @if_else(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
|
||||
%elem = "tf._SomeOp"() : () -> tensor<f32>
|
||||
%push = "tf.StackPushV2"(%arg1, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
|
||||
return %arg1 : tensor<!tf.resource>
|
||||
}
|
@ -4,12 +4,16 @@
|
||||
|
||||
// CHECK: func @non_replicated(%[[ARG0:.*]]: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32>
|
||||
func @non_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32> {
|
||||
// CHECK: %[[COMPILE:.*]]:2 = "tf._TPUCompileMlir"()
|
||||
%1:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
|
||||
// The metadata encodes 2 parameter and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"()
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
%1:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64,
|
||||
// The metadata encodes 2 parameter and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
|
||||
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
// CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
|
||||
// CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false}
|
||||
// CHECK: %[[ITER:.*]]:2 = "tf.IteratorGetNext"
|
||||
@ -17,11 +21,20 @@ func @non_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}
|
||||
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
|
||||
// CHECK-DAG: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER]]#0, %[[LAYOUT0]]) {device = "/device:TPU:0"}
|
||||
// CHECK-DAG: %[[COPY1:.*]] = "tf.TPUCopyWithLayout"(%[[ITER]]#1, %[[LAYOUT1]]) {device = "/device:TPU:0"}
|
||||
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
|
||||
// CHECK: "tf.TPUExecute"(%[[COPY0]], %[[COPY1]], %[[COMPILE]]#1) {device = "/device:TPU:0"}
|
||||
%3 = "tf.TPUExecute"(%2#0, %2#1, %1#1) {device = "/device:TPU:0"}
|
||||
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
|
||||
return %3 : tensor<i32>
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "/device:CPU:0"} : () -> ()
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecute"(%[[COPY0]], %[[COPY1]], %[[COMPILE]]#1)
|
||||
%execute = "tf_device.launch"() ( {
|
||||
%3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1)
|
||||
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
|
||||
tf_device.return %3 : tensor<i32>
|
||||
}) {device = "/device:TPU:0"} : () -> tensor<i32>
|
||||
return %execute : tensor<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -31,22 +44,34 @@ func @non_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}
|
||||
|
||||
// CHECK-LABEL: func @multiple_compile_uses
|
||||
func @multiple_compile_uses(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> tensor<i32> {
|
||||
%1:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
|
||||
// The metadata encodes 2 parameter and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
%1:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64,
|
||||
// The metadata encodes 2 parameter and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
|
||||
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
// CHECK-NOT: "tf.TPUGetLayoutOp"
|
||||
// CHECK-NOT: "tf.TPUCopyWithLayout"
|
||||
%2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
|
||||
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
|
||||
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
|
||||
%3 = "tf.TPUExecute"(%2#0, %2#1, %1#1) {device = "/device:TPU:0"}
|
||||
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "/device:CPU:0"} : () -> ()
|
||||
%execute0 = "tf_device.launch"() ( {
|
||||
%3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1)
|
||||
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
|
||||
tf_device.return %3 : tensor<i32>
|
||||
}) {device = "/device:TPU:0"} : () -> tensor<i32>
|
||||
%4:2 = "tf._UnKnownOp_"() : () -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
|
||||
%5 = "tf.TPUExecute"(%4#0, %4#1, %1#1) {device = "/device:TPU:0"}
|
||||
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
|
||||
return %5 : tensor<i32>
|
||||
%execute1 = "tf_device.launch"() ( {
|
||||
%5 = "tf.TPUExecute"(%4#0, %4#1, %compile#1)
|
||||
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
|
||||
tf_device.return %5 : tensor<i32>
|
||||
}) {device = "/device:TPU:0"} : () -> tensor<i32>
|
||||
return %execute1 : tensor<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -55,19 +80,28 @@ func @multiple_compile_uses(%arg0: tensor<*x!tf.resource> {tf.device = "/device:
|
||||
|
||||
// CHECK-LABEL: func @on_tpu_iter
|
||||
func @on_tpu_iter(%arg0: tensor<*x!tf.resource> {tf.device = "/device:TPU:0"}) -> tensor<i32> {
|
||||
%1:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
|
||||
// The metadata encodes 2 parameter and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
%1:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64,
|
||||
// The metadata encodes 2 parameter and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
|
||||
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
// CHECK-NOT: "tf.TPUGetLayoutOp"
|
||||
// CHECK-NOT: "tf.TPUCopyWithLayout"
|
||||
%2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:TPU:0"}
|
||||
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
|
||||
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
|
||||
%3 = "tf.TPUExecute"(%2#0, %2#1, %1#1) {device = "/device:TPU:0"}
|
||||
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
|
||||
return %3 : tensor<i32>
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "/device:CPU:0"} : () -> ()
|
||||
%execute = "tf_device.launch"() ( {
|
||||
%3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1)
|
||||
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
|
||||
tf_device.return %3 : tensor<i32>
|
||||
}) {device = "/device:TPU:0"} : () -> tensor<i32>
|
||||
return %execute : tensor<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -76,18 +110,27 @@ func @on_tpu_iter(%arg0: tensor<*x!tf.resource> {tf.device = "/device:TPU:0"}) -
|
||||
|
||||
// CHECK-LABEL: func @unsupported_ops
|
||||
func @unsupported_ops(%arg0: tensor<3x3x1x32xf32>) -> tensor<i32> {
|
||||
%1:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
|
||||
// The metadata encodes 2 parameter and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
%1:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64,
|
||||
// The metadata encodes 2 parameter and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
|
||||
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
// CHECK-NOT: "tf.TPUGetLayoutOp"
|
||||
// CHECK-NOT: "tf.TPUCopyWithLayout"
|
||||
%2 = "tf._Unknown_"() : () -> tensor<3x3x1x32xf32>
|
||||
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
|
||||
%3 = "tf.TPUExecute"(%arg0, %2, %1#1) {device = "/device:TPU:0"}
|
||||
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
|
||||
return %3 : tensor<i32>
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "/device:CPU:0"} : () -> ()
|
||||
%execute = "tf_device.launch"() ( {
|
||||
%3 = "tf.TPUExecute"(%arg0, %2, %compile#1)
|
||||
: (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
|
||||
tf_device.return %3 : tensor<i32>
|
||||
}) {device = "/device:TPU:0"} : () -> tensor<i32>
|
||||
return %execute : tensor<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -99,12 +142,16 @@ func @replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) ->
|
||||
// CHECK: %[[ITER0:.*]]:2 = "tf.IteratorGetNext"
|
||||
%2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"}
|
||||
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
|
||||
// CHECK: %[[COMPILE:.*]]:2 = "tf._TPUCompileMlir"()
|
||||
%1:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
|
||||
// The metadata encodes 2 parameter and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"()
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
%1:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64,
|
||||
// The metadata encodes 2 parameter and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
|
||||
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
// CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false}
|
||||
// CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false}
|
||||
// CHECK-DAG: %[[COPY0:.*]] = "tf.TPUCopyWithLayout"(%[[ITER0]]#0, %[[LAYOUT0]]) {device = "/device:TPU:0"}
|
||||
@ -114,13 +161,19 @@ func @replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) ->
|
||||
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
|
||||
// CHECK-DAG: %[[COPY2:.*]] = "tf.TPUCopyWithLayout"(%[[ITER1]]#0, %[[LAYOUT0]]) {device = "/device:TPU:1"}
|
||||
// CHECK-DAG: %[[COPY3:.*]] = "tf.TPUCopyWithLayout"(%[[ITER1]]#1, %[[LAYOUT1]]) {device = "/device:TPU:1"}
|
||||
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "/device:CPU:0"} : () -> ()
|
||||
// CHECK: tf_device.replicate([%[[COPY0]], %[[COPY2]]] as %[[R0:.*]]: tensor<3x3x1x32xf32>, [%[[COPY1]], %[[COPY3]]] as %[[R1:.*]]: tensor<3x3x1x32xf32>)
|
||||
%5:2 = tf_device.replicate([%2#0, %3#0] as %r0: tensor<3x3x1x32xf32>, [%2#1, %3#1] as %r1: tensor<3x3x1x32xf32>)
|
||||
{n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}} {
|
||||
// CHECK: "tf.TPUExecute"(%[[R0]], %[[R1]], %[[COMPILE]]#1)
|
||||
%4 = "tf.TPUExecute"(%r0, %r1, %1#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
|
||||
tf_device.return %4 : tensor<i32>
|
||||
%execute = "tf_device.launch"() ( {
|
||||
%4 = "tf.TPUExecute"(%r0, %r1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
|
||||
tf_device.return %4 : tensor<i32>
|
||||
}) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i32>
|
||||
tf_device.return %execute : tensor<i32>
|
||||
}
|
||||
return %5#0 : tensor<i32>
|
||||
}
|
||||
@ -131,20 +184,29 @@ func @replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) ->
|
||||
|
||||
// CHECK-LABEL: func @inside_replicated
|
||||
func @inside_replicated(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>) -> tensor<i32> {
|
||||
%1:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
|
||||
// The metadata encodes two parameters and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
%1:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64,
|
||||
// The metadata encodes two parameters and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
tf_device.return %1#0, %1#1 : tensor<!tf.string>, tensor<!tf.string>
|
||||
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
// CHECK-NOT: "tf.TPUGetLayoutOp"
|
||||
// CHECK-NOT: "tf.TPUCopyWithLayout"
|
||||
"tf.TPUCompileSucceededAssert"(%1#0) : (tensor<!tf.string>) -> ()
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "/device:CPU:0"} : () -> ()
|
||||
%5:2 = tf_device.replicate([%arg0, %arg1] as %r0: tensor<*x!tf.resource>)
|
||||
{n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}} {
|
||||
%2:2 = "tf.IteratorGetNext"(%r0)
|
||||
: (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>)
|
||||
%4 = "tf.TPUExecute"(%2#0, %2#1, %1#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
|
||||
tf_device.return %4 : tensor<i32>
|
||||
%execute = "tf_device.launch"() ( {
|
||||
%4 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<!tf.string>) -> tensor<i32>
|
||||
tf_device.return %4 : tensor<i32>
|
||||
}) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor<i32>
|
||||
tf_device.return %execute : tensor<i32>
|
||||
}
|
||||
return %5#0 : tensor<i32>
|
||||
}
@ -22,16 +22,21 @@ func @merge_same_device_variables(
|
||||
%read0 = "tf.ReadVariableOp"(%id0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<64xf32>>>) -> tensor<64xf32>
|
||||
%read2 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf.resource<tensor<16xf32>>>) -> tensor<16xf32>
|
||||
// CHECK-NEXT: %[[EXE:.*]] = "tf.TPUExecuteAndUpdateVariables"(%[[ID_0]], %[[ARG_1]], %[[READ_2]], %[[ARG_3]])
|
||||
// CHECK-NEXT: %[[EXE:.*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[ID_0]], %[[ARG_1]], %[[READ_2]], %[[ARG_3]])
|
||||
// CHECK-SAME: device_var_reads_indices = [0, 1],
|
||||
// CHECK-SAME: device_var_updates_indices = [0, -1]
|
||||
%execute:2 = "tf.TPUExecute"(%read0, %read1, %read2, %arg3) {
|
||||
Targs = [tensor<32xf32>, tensor<64xf32>, tensor<16xf32>],
|
||||
Tresults = [tensor<32xf32>, tensor<16xf32>],
|
||||
device = "/job:localhost/replica:0/task:0/device:TPU:0"}
|
||||
: (tensor<32xf32>, tensor<64xf32>, tensor<16xf32>, tensor<!tf.string>) -> (tensor<32xf32>, tensor<16xf32>)
|
||||
// CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_2]], %[[EXE]])
|
||||
%execute:2 = "tf_device.launch"() ( {
|
||||
%0:2 = "tf.TPUExecute"(%read0, %read1, %read2, %arg3) {
|
||||
Targs = [tensor<32xf32>, tensor<64xf32>, tensor<16xf32>],
|
||||
Tresults = [tensor<32xf32>, tensor<16xf32>]}
|
||||
: (tensor<32xf32>, tensor<64xf32>, tensor<16xf32>, tensor<!tf.string>) -> (tensor<32xf32>, tensor<16xf32>)
|
||||
tf_device.return %0#0, %0#1 : tensor<32xf32>, tensor<16xf32>
|
||||
}) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<16xf32>)
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"}
|
||||
"tf.AssignVariableOp"(%id0, %execute#0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_2]], %[[EXE]])
|
||||
"tf.AssignVariableOp"(%arg2, %execute#1) : (tensor<*x!tf.resource<tensor<16xf32>>>, tensor<16xf32>) -> ()
|
||||
// CHECK-NEXT: tf_executor.yield
|
||||
tf_executor.yield
|
||||
@ -61,12 +66,18 @@ func @merge_replicated_variables(
|
||||
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// CHECK-NEXT: tf_device.replicate([%[[ARG_2]], %[[ARG_3]]] as %[[R_ARG:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>)
|
||||
tf_device.replicate([%arg2, %arg3] as %r: tensor<*x!tf.resource<tensor<32xf32>>>) {n = 2 : i32} {
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[R_ARG]], %[[ARG_1]])
|
||||
// CHECK-SAME: device_var_reads_indices = [1],
|
||||
// CHECK-SAME: device_var_updates_indices = [0]
|
||||
%read1 = "tf.ReadVariableOp"(%r) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
%execute = "tf.TPUExecute"(%read0, %read1, %arg1)
|
||||
: (tensor<32xf32>, tensor<32xf32>, tensor<!tf.string>) -> tensor<32xf32>
|
||||
%execute = "tf_device.launch"() ( {
|
||||
%0 = "tf.TPUExecute"(%read0, %read1, %arg1)
|
||||
: (tensor<32xf32>, tensor<32xf32>, tensor<!tf.string>) -> tensor<32xf32>
|
||||
tf_device.return %0 : tensor<32xf32>
|
||||
}) {device = ""} : () -> tensor<32xf32>
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK-NEXT: }) {device = ""}
|
||||
"tf.AssignVariableOp"(%r, %execute) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// CHECK-NEXT: tf_device.return
|
||||
tf_device.return
|
||||
@ -114,15 +125,20 @@ func @interferencing_accesses(
|
||||
"tf.AssignVariableOp"(%arg5, %arg6) : (tensor<*x!tf.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
|
||||
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<64xf32>>>) -> tensor<64xf32>
|
||||
%read2 = "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
|
||||
// CHECK-NEXT: %[[EXE:.*]]:2 = "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[ARG_1]], %[[ARG_4]], %[[READ_5]], %[[ARG_3]])
|
||||
// CHECK-NEXT: %[[EXE:.*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[ARG_1]], %[[ARG_4]], %[[READ_5]], %[[ARG_3]])
|
||||
// CHECK-SAME: device_var_reads_indices = [1, 2],
|
||||
// CHECK-SAME: device_var_updates_indices = [1, -1]
|
||||
%execute:3 = "tf.TPUExecute"(%read0, %read1, %read2, %read5, %arg3) {
|
||||
Targs = [tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>],
|
||||
Tresults = [tensor<32xf32>, tensor<64xf32>, tensor<8xf32>],
|
||||
device = "/job:localhost/replica:0/task:0/device:TPU:0"}
|
||||
: (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>, tensor<!tf.string>)
|
||||
-> (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>)
|
||||
%execute:3 = "tf_device.launch"() ( {
|
||||
%0:3 = "tf.TPUExecute"(%read0, %read1, %read2, %read5, %arg3) {
|
||||
Targs = [tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>],
|
||||
Tresults = [tensor<32xf32>, tensor<64xf32>, tensor<8xf32>]}
|
||||
: (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>, tensor<!tf.string>)
|
||||
-> (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>)
|
||||
tf_device.return %0#0, %0#1, %0#2 : tensor<32xf32>, tensor<64xf32>, tensor<8xf32>
|
||||
}) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>)
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"}
|
||||
"tf.AssignVariableOp"(%arg1, %execute#1) : (tensor<*x!tf.resource<tensor<64xf32>>>, tensor<64xf32>) -> ()
|
||||
// CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]#0)
|
||||
"tf.AssignVariableOp"(%arg0, %execute#0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
@ -156,11 +172,16 @@ func @do_not_merge_multi_read(
|
||||
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// CHECK-NEXT: %[[READ_1:.*]] = "tf.ReadVariableOp"(%[[ARG_0]])
|
||||
%read1 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// CHECK-NEXT: %[[EXE:.*]] = "tf.TPUExecute"(%[[READ_0]], %[[READ_1]], %[[ARG_1]])
|
||||
%execute = "tf.TPUExecute"(%read0, %read1, %arg1) {
|
||||
Targs = [tensor<32xf32>, tensor<32xf32>], Tresults = [tensor<32xf32>],
|
||||
device = "/job:localhost/replica:0/task:0/device:TPU:0"}
|
||||
: (tensor<32xf32>, tensor<32xf32>, tensor<!tf.string>) -> (tensor<32xf32>)
|
||||
// CHECK-NEXT: %[[EXE:.*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecute"(%[[READ_0]], %[[READ_1]], %[[ARG_1]])
|
||||
%execute = "tf_device.launch"() ( {
|
||||
%0 = "tf.TPUExecute"(%read0, %read1, %arg1) {
|
||||
Targs = [tensor<32xf32>, tensor<32xf32>], Tresults = [tensor<32xf32>]}
|
||||
: (tensor<32xf32>, tensor<32xf32>, tensor<!tf.string>) -> (tensor<32xf32>)
|
||||
tf_device.return %0 : tensor<32xf32>
|
||||
}) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<32xf32>
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"}
|
||||
// CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]])
|
||||
"tf.AssignVariableOp"(%arg0, %execute) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// CHECK-NEXT: tf_executor.yield
|
||||
@ -187,11 +208,16 @@ func @do_not_merge_multi_assign(
|
||||
%island = tf_executor.island {
|
||||
// CHECK-NEXT: %[[READ_0:.*]] = "tf.ReadVariableOp"(%[[ARG_0]])
|
||||
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
// CHECK-NEXT: %[[EXE:.*]]:2 = "tf.TPUExecute"(%[[READ_0]], %[[ARG_1]])
|
||||
%execute:2 = "tf.TPUExecute"(%read0, %arg1) {
|
||||
Targs = [tensor<32xf32>], Tresults = [tensor<32xf32>, tensor<32xf32>],
|
||||
device = "/job:localhost/replica:0/task:0/device:TPU:0"}
|
||||
: (tensor<32xf32>, tensor<!tf.string>) -> (tensor<32xf32>, tensor<32xf32>)
|
||||
// CHECK-NEXT: %[[EXE:.*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecute"(%[[READ_0]], %[[ARG_1]])
|
||||
%execute:2 = "tf_device.launch"() ( {
|
||||
%0:2 = "tf.TPUExecute"(%read0, %arg1) {
|
||||
Targs = [tensor<32xf32>], Tresults = [tensor<32xf32>, tensor<32xf32>]}
|
||||
: (tensor<32xf32>, tensor<!tf.string>) -> (tensor<32xf32>, tensor<32xf32>)
|
||||
tf_device.return %0#0, %0#1 : tensor<32xf32>, tensor<32xf32>
|
||||
}) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<32xf32>)
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK-NEXT: }) {device = "/job:localhost/replica:0/task:0/device:TPU:0"}
|
||||
// CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]#0)
|
||||
"tf.AssignVariableOp"(%arg0, %execute#0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
|
||||
// CHECK-NEXT: "tf.AssignVariableOp"(%[[ARG_0]], %[[EXE]]#1)
@ -31,7 +31,11 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
// CHECK-SAME: as %[[V0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
// CHECK-SAME: as %[[V1:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
// CHECK-SAME: [%[[STATE0]], %[[STATE1]]] as %[[STATE:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>>
|
||||
// CHECK: "tf.TPUReshardVariables"(%[[V0]], %[[V1]], %[[DEFAULT]], %[[STATE]])
|
||||
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[V0]], %[[V1]], %[[DEFAULT]], %[[STATE]])
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
|
||||
return
|
||||
}
|
||||
// CHECK: func @while_body_7560
|
||||
@ -51,27 +55,41 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
// CHECK-SAME: %[[STATE_ARG1:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>> {tf.device = "/device:TPU:1"})
|
||||
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
// CHECK: %[[COMPILE:.*]]:2 = "tf._TPUCompileMlir"()
|
||||
%2:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
|
||||
// The metadata encodes two parameters and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
"tf.TPUCompileSucceededAssert"(%2#0) : (tensor<!tf.string>) -> ()
|
||||
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"()
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
%2:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64,
|
||||
// The metadata encodes two parameters and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<!tf.string>
|
||||
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "/device:CPU:0"} : () -> ()
|
||||
// CHECK: tf_device.replicate
|
||||
// CHECK-SAME: [%[[BODY_ARG1]], %[[BODY_ARG2]]] as %[[R0:.*]]: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
// CHECK-SAME: [%[[BODY_ARG3]], %[[BODY_ARG4]]] as %[[R1:.*]]: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
// CHECK-SAME: [%[[STATE_ARG0]], %[[STATE_ARG1]]] as %[[R_STATE:.*]]: tensor<!tf.resource<tensor<2x!tf.string>>>
|
||||
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]
|
||||
%rep:2 = tf_device.replicate([%arg1, %arg2] as %arg30: tensor<*x!tf.resource<tensor<f32>>>,
|
||||
[%arg3, %arg4] as %arg31: tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>)
|
||||
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
|
||||
// CHECK: %[[ID:.*]] = "tf.Identity"(%[[R0]])
|
||||
%id = "tf.Identity"(%arg30) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
// CHECK: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUReshardVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1, %[[R_STATE]])
|
||||
// CHECK-NEXT: tf_device.return
|
||||
// CHECK-NEXT: device = "TPU_REPLICATED_CORE_0"
|
||||
// CHECK: "tf.TPUExecuteAndUpdateVariables"(%[[ID]], %[[R1]], %[[COMPILE]]#1)
|
||||
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %2#1)
|
||||
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<!tf.string>) -> ()
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1)
|
||||
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>, tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
|
||||
%ret = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
|
||||
tf_device.return %ret : tensor<i32>
|
||||
}
|
||||
@ -131,12 +149,18 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<f32>>>) {
|
||||
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%2:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64, device = "/device:CPU:0",
|
||||
// The metadata encodes two parameters and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
"tf.TPUCompileSucceededAssert"(%2#0) : (tensor<!tf.string>) -> ()
|
||||
%compile:2 = "tf_device.launch"() ( {
|
||||
%2:2 = "tf._TPUCompileMlir"() {
|
||||
NumDynamicShapes = 0 : i64,
|
||||
// The metadata encodes two parameters and two return values.
|
||||
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01",
|
||||
mlir_module = "..."} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
tf_device.return %2#0, %2#1 : tensor<!tf.string>, tensor<!tf.string>
|
||||
}) {device = "/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<!tf.string>)
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUCompileSucceededAssert"(%compile#0) : (tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "/device:CPU:0"} : () -> ()
|
||||
%id0 = "tf.Identity"(%arg3) : (tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) -> tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>
|
||||
"tf._Unknown_"(%id0) : (tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>) -> ()
|
||||
%newvar = "tf._SomeOp"() : () -> tensor<*x!tf.resource<tensor<f32>>>
|
||||
@ -146,10 +170,13 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
||||
{_mirrored_variable_indices = [0, 1], devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}, n = 2 : i32} {
|
||||
// %arg30 is used in the cond function, %arg31 has other uses (%id0), and
|
||||
// %arg32 is not a pass-through.
|
||||
"tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %arg32, %2#1)
|
||||
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<!tf.string>) -> ()
|
||||
"tf_device.launch"() ( {
|
||||
"tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %arg32, %compile#1)
|
||||
{device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]}
|
||||
: (tensor<*x!tf.resource<tensor<f32>>>, tensor<*x!tf.resource<tensor<3x3x1x32xf32>>>,
|
||||
tensor<*x!tf.resource<tensor<f32>>>, tensor<!tf.string>) -> ()
|
||||
tf_device.return
|
||||
}) {device = "TPU_REPLICATED_CORE_0"} : () -> ()
|
||||
tf_device.return
|
||||
}
|
||||
return %1, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6 : tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>,
@ -431,16 +431,21 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
|
||||
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK-SAME: NumDynamicShapes = 1
|
||||
// CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
// CHECK-SAME: metadata
|
||||
// CHECK-SAME: mlir_module
|
||||
// CHECK-SAME: func @main
|
||||
// CHECK-SAME: tf.B
|
||||
// CHECK-NOT: func = @tpu0_func
|
||||
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
|
||||
// CHECK-SAME: device = "/job:worker/replica:0/task:0/device:TPU:0"
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:TPU:0"
|
||||
|
||||
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
|
||||
@ -472,15 +477,20 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-SAME: n = 2
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[RI_0]])
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK-SAME: NumDynamicShapes = 1
|
||||
// CHECK-SAME: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
// CHECK-SAME: metadata
|
||||
// CHECK-SAME: mlir_module
|
||||
// CHECK-SAME: func @main
|
||||
// CHECK-SAME: tf.B
|
||||
// CHECK-NOT: func = @tpu0_func
|
||||
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[RI_0]], %[[COMPILE_OUTPUT]]#1)
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecute"(%[[RI_0]], %[[COMPILE_OUTPUT]]#1)
|
||||
%2 = "tf_device.launch_func"(%ri_0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
|
||||
// CHECK: tf_device.return %[[EXECUTE_OUTPUT]]
|
||||
@ -539,7 +549,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
|
||||
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK-SAME: NumDynamicShapes = 1
|
||||
// CHECK-SAME: metadata
|
||||
// CHECK-SAME: mlir_module
|
||||
@ -548,7 +559,13 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-SAME: func @nested_func
|
||||
// CHECK-SAME: tf.D
|
||||
// CHECK-NOT: func = @tpu0_func
|
||||
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:TPU:0"
|
||||
|
||||
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
|
||||
@ -582,7 +599,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
|
||||
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK-SAME: NumDynamicShapes = 1
|
||||
// CHECK-SAME: metadata
|
||||
// CHECK-SAME: mlir_module
|
||||
@ -591,7 +609,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-SAME: func @referenced_func
|
||||
// CHECK-SAME: tf.D
|
||||
// CHECK-NOT: func = @tpu0_func
|
||||
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
|
||||
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
|
||||
|
||||
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
|
||||
@ -624,7 +645,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
|
||||
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK-SAME: NumDynamicShapes = 1
|
||||
// CHECK-SAME: metadata
|
||||
// CHECK-SAME: mlir_module
|
||||
@ -635,7 +657,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-SAME: @referenced_func2
|
||||
// CHECK-SAME: tf.E
|
||||
// CHECK-NOT: func = @tpu0_func
|
||||
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
|
||||
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
|
||||
|
||||
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
|
||||
@ -674,7 +699,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
|
||||
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK-SAME: NumDynamicShapes = 1
|
||||
// CHECK-SAME: metadata
|
||||
// CHECK-SAME: mlir_module
|
||||
@ -684,7 +710,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-COUNT-1: func @referenced_func
|
||||
// CHECK-SAME: tf.D
|
||||
// CHECK-NOT: func = @tpu0_func
|
||||
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
|
||||
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
|
||||
|
||||
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
|
||||
@ -718,25 +747,33 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
|
||||
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func0, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
|
||||
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK-SAME: NumDynamicShapes = 1
|
||||
// CHECK-SAME: metadata
|
||||
// CHECK-SAME: mlir_module
|
||||
// CHECK-SAME: func @main
|
||||
// CHECK-SAME: tf.B
|
||||
// CHECK-NOT: func = @tpu0_func0
|
||||
// CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1)
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE0_OUTPUT]]#0)
|
||||
// CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1)
|
||||
|
||||
%2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "", func = @tpu0_func1, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]])
|
||||
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
|
||||
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
|
||||
// CHECK-SAME: NumDynamicShapes = 1
|
||||
// CHECK-SAME: metadata
|
||||
// CHECK-SAME: mlir_module
|
||||
// CHECK-SAME: func @main
|
||||
// CHECK-SAME: tf.D
|
||||
// CHECK-NOT: func = @tpu0_func1
|
||||
// CHECK: %[[EXECUTE1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[EXECUTE0_OUTPUT]], %[[COMPILE1_OUTPUT]]#1)
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE1_OUTPUT]]#0)
|
||||
// CHECK: %[[EXECUTE1_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecute"(%[[EXECUTE0_OUTPUT]], %[[COMPILE1_OUTPUT]]#1)
|
||||
|
||||
%3 = "tf.C"(%2) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE1_OUTPUT]])
|
||||
@ -768,25 +805,33 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
|
||||
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
|
||||
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK-SAME: NumDynamicShapes = 1
|
||||
// CHECK-SAME: metadata
|
||||
// CHECK-SAME: mlir_module
|
||||
// CHECK-SAME: func @main
|
||||
// CHECK-SAME: tf.B
|
||||
// CHECK-NOT: func = @tpu0_func
|
||||
// CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1)
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE0_OUTPUT]]#0)
|
||||
// CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1)
|
||||
|
||||
%2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]])
|
||||
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
|
||||
// CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[EXECUTE0_SHAPE_OUTPUT]])
|
||||
// CHECK-SAME: NumDynamicShapes = 1
|
||||
// CHECK-SAME: metadata
|
||||
// CHECK-SAME: mlir_module
|
||||
// CHECK-SAME: func @main
|
||||
// CHECK-SAME: tf.B
|
||||
// CHECK-NOT: func = @tpu0_func
|
||||
// CHECK: %[[EXECUTE1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[EXECUTE0_OUTPUT]], %[[COMPILE1_OUTPUT]]#1)
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE1_OUTPUT]]#0)
|
||||
// CHECK: %[[EXECUTE1_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecute"(%[[EXECUTE0_OUTPUT]], %[[COMPILE1_OUTPUT]]#1)
|
||||
|
||||
%3 = "tf.C"(%2) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE1_OUTPUT]])
|
||||
@ -814,7 +859,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
|
||||
%1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"]} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]])
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"(%[[A_SHAPE_OUTPUT]])
|
||||
// CHECK-SAME: NumDynamicShapes = 1
|
||||
// CHECK-SAME: metadata
|
||||
// CHECK-SAME: mlir_module
|
||||
@ -828,7 +874,10 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-SAME: tf.G
|
||||
// CHECK-SAME: func @referenced_func0
|
||||
// CHECK-SAME: tf.F
|
||||
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)
|
||||
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1)
|
||||
|
||||
%2 = "tf.C"(%1) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]])
|
||||
@ -873,8 +922,12 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-LABEL: func @tpu_compilation_result
|
||||
func @tpu_compilation_result(%arg0: tensor<?xi32>) -> (tensor<?xi32>, tensor<!tf.string>, tensor<!tf.string>) {
|
||||
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"
|
||||
// CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf._TPUCompileMlir"
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"
|
||||
// CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf_device.launch"
|
||||
// CHECK-NEXT: "tf.TPUExecute"
|
||||
%1 = "tf_device.launch_func"(%arg0) {_tpu_replicate = "cluster0", device = "", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "", padding_map = []} : (tensor<?xi32>) -> tensor<?xi32>
|
||||
|
||||
%compile_result = "tf.TPUCompilationResult"() {_tpu_replicate = "cluster0"} : () -> tensor<!tf.string>
|
||||
|
@ -0,0 +1,147 @@
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-sharding-identification | FileCheck %s --dump-input=fail

// Tests empty launch func. Empty input/output sharding configuration
// attributes must be added.
// CHECK-LABEL: func @check_sharding_attrs_exists_for_empty_launch_func
func @check_sharding_attrs_exists_for_empty_launch_func() {
  "tf_device.launch_func"() {device = "", func = @empty_func, step_marker_location = ""} : () -> ()
  // CHECK: input_sharding_configuration = []
  // CHECK: output_sharding_configuration = []
  return
}

func @empty_func() {
  return
}

// -----
|
||||
|
||||
// Tests that inputs/outputs with no XLA sharding op attached get the
// default maximal(0) sharding configuration.
|
||||
// CHECK-LABEL: func @check_default_sharding_for_inputs_outputs
|
||||
func @check_default_sharding_for_inputs_outputs(%arg0: tensor<*xi32>) {
|
||||
"tf_device.launch_func"(%arg0) {device = "", func = @func_without_sharding, step_marker_location = ""} : (tensor<*xi32>) -> ()
|
||||
// CHECK: input_sharding_configuration
|
||||
// CHECK-SAME: ["\08\01\1A\01\01\22\01\00"]
|
||||
// CHECK: output_sharding_configuration
|
||||
// CHECK-SAME: ["\08\01\1A\01\01\22\01\00"]
|
||||
return
|
||||
}
|
||||
|
||||
func @func_without_sharding(%arg0: tensor<*xi32>) -> tensor<*xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
|
||||
return %0 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests with an input arg connected to an XlaSharding op.
|
||||
// CHECK-LABEL: func @check_sharding_for_input_correctly_identified
|
||||
func @check_sharding_for_input_correctly_identified(%arg0: tensor<*xi32>) {
|
||||
"tf_device.launch_func"(%arg0) {device = "", func = @inputs_with_sharding_func, step_marker_location = ""} : (tensor<*xi32>) -> ()
|
||||
// CHECK: input_sharding_configuration
|
||||
// CHECK-SAME: ["\01\02\03"]
|
||||
// CHECK: output_sharding_configuration
|
||||
// CHECK-SAME: ["\08\01\1A\01\01\22\01\00"]
|
||||
return
|
||||
}
|
||||
|
||||
func @inputs_with_sharding_func(%arg0: tensor<*xi32>) -> tensor<*xi32> {
|
||||
%0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
|
||||
%1 = "tf.A"(%0) : (tensor<*xi32>) -> (tensor<*xi32>)
|
||||
return %1 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that sharding is correctly parsed for multiple inputs/outputs.
|
||||
// CHECK-LABEL: func @check_sharding_for_multiple_inputs_outputs
|
||||
func @check_sharding_for_multiple_inputs_outputs(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
|
||||
"tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
|
||||
// CHECK: input_sharding_configuration
|
||||
// CHECK-SAME: ["\01\02\03", "\04\05\06"]
|
||||
// CHECK: output_sharding_configuration
|
||||
// CHECK-SAME: ["\0A\0B\0C", "\0D\0E\0F"]
|
||||
return
|
||||
}
|
||||
|
||||
func @func_with_sharding(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
|
||||
%0 = "tf.XlaSharding"(%arg0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
|
||||
%1 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
|
||||
%2, %3 = "tf.A"(%0, %1) : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
|
||||
%4 = "tf.XlaSharding"(%2) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
|
||||
%5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
|
||||
return %4, %5 : tensor<*xi32> , tensor<*xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests with input sharding following an identity op.
|
||||
// CHECK-LABEL: func @check_sharding_after_identity
|
||||
func @check_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
|
||||
"tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding_after_identity, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
|
||||
// CHECK: input_sharding_configuration
|
||||
// CHECK-SAME: ["\01\02\03", "\04\05\06"]
|
||||
// CHECK: output_sharding_configuration
|
||||
// CHECK-SAME: ["\0A\0B\0C", "\0D\0E\0F"]
|
||||
return
|
||||
}
|
||||
|
||||
func @func_with_sharding_after_identity(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
|
||||
%0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
|
||||
%1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<*xi32>) -> tensor<*xi32>
|
||||
%2 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
|
||||
%3, %4 = "tf.A"(%1, %2) : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
|
||||
%5 = "tf.XlaSharding"(%3) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
|
||||
%6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
|
||||
return %5, %6 : tensor<*xi32> , tensor<*xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests with input sharding following a ReadVariable op.
|
||||
// CHECK-LABEL: func @check_sharding_after_read_variable
|
||||
func @check_sharding_after_read_variable(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
|
||||
"tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding_after_read_variable, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
|
||||
// CHECK: input_sharding_configuration
|
||||
// CHECK-SAME: ["\01\02\03", "\04\05\06"]
|
||||
// CHECK: output_sharding_configuration
|
||||
// CHECK-SAME: ["\0A\0B\0C", "\0D\0E\0F"]
|
||||
return
|
||||
}
|
||||
|
||||
func @func_with_sharding_after_read_variable(%arg0: tensor<*x!tf.resource<tensor<32xf32>>>, %arg1: tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<*xi32>, tensor<*xi1>) {
|
||||
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
%1 = "tf.XlaSharding"(%0) { _XlaSharding = "\01\02\03" } : (tensor<32xf32>) -> tensor<32xf32>
|
||||
%2 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
|
||||
%3 = "tf.Identity"(%2) : (tensor<32xf32>) -> tensor<32xf32>
|
||||
%4 = "tf.XlaSharding"(%3) { _XlaSharding = "\04\05\06" } : (tensor<32xf32>) -> tensor<32xf32>
|
||||
%5, %6 = "tf.A"(%1, %3) : (tensor<32xf32>, tensor<32xf32>) -> (tensor<*xi32>, tensor<*xi1>)
|
||||
%7 = "tf.XlaSharding"(%5) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
|
||||
%8 = "tf.XlaSharding"(%6) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
|
||||
return %7, %8 : tensor<*xi32> , tensor<*xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests with input sharding following an identity op and a cast op.
|
||||
// CHECK-LABEL: func @check_sharding_after_cast_op
|
||||
func @check_sharding_after_cast_op(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
|
||||
"tf_device.launch_func"(%arg0, %arg1) {device = "", func = @func_with_sharding_after_cast, step_marker_location = ""} : (tensor<*xi32>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
|
||||
// CHECK: input_sharding_configuration
|
||||
// CHECK-SAME: ["\01\02\03", "\04\05\06"]
|
||||
// CHECK: output_sharding_configuration
|
||||
// CHECK-SAME: ["\0A\0B\0C", "\0D\0E\0F"]
|
||||
return
|
||||
}
|
||||
|
||||
func @func_with_sharding_after_cast(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>) {
|
||||
%0 = "tf.Identity"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
|
||||
%1 = "tf.Cast"(%0) : (tensor<*xi32>) -> tensor<*xi1>
|
||||
%2 = "tf.XlaSharding"(%1) { _XlaSharding = "\01\02\03" } : (tensor<*xi1>) -> tensor<*xi1>
|
||||
%3 = "tf.XlaSharding"(%arg1) { _XlaSharding = "\04\05\06" } : (tensor<*xi1>) -> tensor<*xi1>
|
||||
%4, %5 = "tf.A"(%2, %3) : (tensor<*xi1>, tensor<*xi1>) -> (tensor<*xi32>, tensor<*xi1>)
|
||||
%6 = "tf.XlaSharding"(%4) { _XlaSharding = "\0A\0B\0C" } : (tensor<*xi32>) -> tensor<*xi32>
|
||||
%7 = "tf.XlaSharding"(%5) { _XlaSharding = "\0D\0E\0F" } : (tensor<*xi1>) -> tensor<*xi1>
|
||||
return %6, %7 : tensor<*xi32> , tensor<*xi1>
|
||||
}
@ -1,23 +1,23 @@
|
||||
// RUN: tf-opt -split-input-file -verify-diagnostics -tfl-unroll-batch-matmul %s | FileCheck %s
|
||||
// RUN: tf-opt -split-input-file -verify-diagnostics -tf-unroll-batch-matmul %s | FileCheck %s
|
||||
|
||||
func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
|
||||
return %0 : tensor<2x3x4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulV2TwoDim
|
||||
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
|
||||
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
|
||||
// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
|
||||
// CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[3, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[4, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[5, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_9:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_10:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
|
||||
// CHECK: %[[cst_11:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
|
||||
|
||||
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
@ -67,16 +67,16 @@ func @batchMatMulV2FlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>)
|
||||
return %0 : tensor<3x4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulV2FlatInput
|
||||
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
|
||||
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
|
||||
// CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
|
||||
// CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[3, 4, 6]> : tensor<3xi64>}
|
||||
|
||||
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
@ -122,19 +122,19 @@ func @batchMatMulTwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>)
|
||||
return %0 : tensor<2x3x4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulTwoDim
|
||||
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
|
||||
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
|
||||
// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[6, 4, 5]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
|
||||
// CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[6, 5, 6]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[3, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[4, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[5, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_9:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_10:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
|
||||
// CHECK: %[[cst_11:.*]] = "tf.Const"() {value = dense<[2, 3, 4, 6]> : tensor<4xi64>}
|
||||
|
||||
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
@ -184,16 +184,16 @@ func @batchMatMulFlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -
|
||||
return %0 : tensor<3x4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulFlatInput
|
||||
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
|
||||
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>}
|
||||
// CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>}
|
||||
// CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>}
|
||||
// CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[3, 4, 6]> : tensor<3xi64>}
|
||||
|
||||
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
@ -25,6 +25,17 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"

namespace mlir {
namespace {
// Add logger to bridge passmanager.
void EnableLogging(PassManager *pm) {
  // Print the whole module after each pass, which requires disabling
  // multi-threading as well.
  pm->disableMultithreading();
  pm->enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
      /*print_module_scope=*/true));
}
} // namespace

namespace TFTPU {
namespace {
void AddGraphExportLoweringPasses(OpPassManager &pm) {
@ -32,16 +43,14 @@ void AddGraphExportLoweringPasses(OpPassManager &pm) {
  pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
  pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateToIslandPass());
  pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass());
  pm.addNestedPass<FuncOp>(TFDevice::CreateLaunchToDeviceAttributePass());
}

tensorflow::Status RunTPUBridge(
    ModuleOp module, bool enable_logging,
    llvm::function_ref<void(OpPassManager &pm)> pipeline_builder) {
  PassManager bridge(module.getContext());

  // Add logger to bridge passmanager.
  if (enable_logging)
    bridge.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>());
  if (enable_logging) EnableLogging(&bridge);

  // Populate a passmanager with the list of passes that implement the bridge.
  pipeline_builder(bridge);
@ -117,10 +126,7 @@ tensorflow::Status RunBridgeWithStandardPipeline(ModuleOp module,
|
||||
bool enable_logging,
|
||||
bool enable_inliner) {
|
||||
PassManager bridge(module.getContext());
|
||||
|
||||
// Add logger to bridge passmanager.
|
||||
if (enable_logging)
|
||||
bridge.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>());
|
||||
if (enable_logging) EnableLogging(&bridge);
|
||||
|
||||
StandardPipelineOptions pipeline_options;
|
||||
pipeline_options.enable_inliner.setValue(enable_inliner);
|
||||
|
@ -43,6 +43,9 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateMaterializePassthroughOpPass();
|
||||
// Performs Shape Inference on the TensorFlow dialect using the global registry.
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateTFShapeInferencePass();
|
||||
|
||||
// Optional pass which will unroll BatchMatMul and use only MatMul
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateUnrollBatchMatMulPassPass();
|
||||
|
||||
// Optimizes TensorFlow graph.
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateTFOptimizePass();
|
||||
|
||||
@ -99,6 +102,11 @@ std::unique_ptr<OpPassBase<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
|
||||
// Performs resource lifting on the function body to hoist resource variable
|
||||
// accesses outside all control flow statements.
|
||||
LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function);
|
||||
|
||||
// Converts stack ops into operations on local variables, which can later be
|
||||
// removed by resource lifting. Requires known maximum sizes of stacks and
|
||||
// known element shapes of push ops.
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateStackOpsDecompositionPass();
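// For illustration (a hypothetical pipeline, not one defined in this file),
// this pass would typically be scheduled on a module-level pass manager ahead
// of resource lifting:
//   mlir::PassManager pm(module.getContext());
//   pm.addPass(mlir::TF::CreateStackOpsDecompositionPass());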
|
||||
} // namespace TF
|
||||
|
||||
namespace TFControlFlow {
|
||||
@ -209,6 +217,10 @@ std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUDynamicPaddingMapperPass();
|
||||
// ops.
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateTPURewritePass();
|
||||
|
||||
// Creates a pass that identifies XLASharding ops in launch op for TPU
|
||||
// computation.
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUShardingIdentificationPass();
|
||||
|
||||
// Creates a pass that merges device variable reads/updates into the surrounded
|
||||
// TPUExecute node. This allows the execute node to perform in-place variable
|
||||
// updates.
|
||||
|
@ -47,22 +47,10 @@ struct ReplicateToIslandPass : public FunctionPass<ReplicateToIslandPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
// Get the device name for which replica_index-th block will execute.
|
||||
llvm::StringRef GetDeviceNameFromAttribute(DictionaryAttr devices,
|
||||
int replica_index) {
|
||||
// TODO(b/148913020): Remove this constraint once model parallelism is
|
||||
// supported.
|
||||
DCHECK_EQ(devices.size(), 1);
|
||||
Attribute device_attr = devices.begin()->second;
|
||||
return device_attr.cast<ArrayAttr>()
|
||||
.getValue()[replica_index]
|
||||
.cast<StringAttr>()
|
||||
.getValue();
|
||||
}
|
||||
|
||||
// Creates islands per replica from `tf_device.replicate` region. TensorFlow ops
|
||||
// will have their device set to the replica if they originally did not have a
|
||||
// device assigned.
|
||||
// Creates islands per replica from `tf_device.replicate` region. If for a
|
||||
// `tf_device.launch` op the device is an aliased device of the
|
||||
// `tf_device.replicate`, the device will be remapped to an explicit device
|
||||
// for the associated replica island.
|
||||
llvm::SmallVector<tf_executor::IslandOp, 8> ExpandReplicateIntoReplicas(
|
||||
const Dialect* tf_dialect, OpBuilder* builder,
|
||||
tf_executor::IslandOp island_op, tf_device::ReplicateOp replicate_op,
|
||||
@ -87,10 +75,6 @@ llvm::SmallVector<tf_executor::IslandOp, 8> ExpandReplicateIntoReplicas(
|
||||
builder->setInsertionPoint(island_op);
|
||||
BlockAndValueMapping mapping;
|
||||
for (int i : llvm::seq<int>(0, num_replicas)) {
|
||||
// Determine optional device.
|
||||
llvm::StringRef device =
|
||||
has_devices ? GetDeviceNameFromAttribute(devices.getValue(), i) : "";
|
||||
|
||||
// Create new island for replica.
|
||||
auto replica = builder->create<tf_executor::IslandOp>(
|
||||
island_op.getLoc(), output_types, control_type, replica_inputs);
|
||||
@ -104,13 +88,13 @@ llvm::SmallVector<tf_executor::IslandOp, 8> ExpandReplicateIntoReplicas(
|
||||
// Copy over replicate region into replica island.
|
||||
replicate_op.body().cloneInto(&replica.body(), mapping);
|
||||
|
||||
// Assign all TF ops in island optional device, if device is set.
|
||||
if (!device.empty()) {
|
||||
StringAttr device_attr = builder->getStringAttr(device);
|
||||
replica.walk([&](Operation* op) {
|
||||
if (op->getDialect() != tf_dialect) return;
|
||||
|
||||
if (!op->getAttr(kDeviceAttr)) op->setAttr(kDeviceAttr, device_attr);
|
||||
// Map aliased devices to explicit devices based on replica.
|
||||
if (has_devices) {
|
||||
replica.walk([&](tf_device::LaunchOp launch) {
|
||||
if (auto device_by_replica = devices.getValue().get(launch.device()))
|
||||
launch.setAttr(
|
||||
kDeviceAttr,
|
||||
device_by_replica.cast<ArrayAttr>()[i].cast<StringAttr>());
|
||||
});
|
||||
}
|
||||
|
||||
@ -124,16 +108,25 @@ llvm::SmallVector<tf_executor::IslandOp, 8> ExpandReplicateIntoReplicas(
|
||||
// replicate results with new island outputs. A single island is created to
|
||||
// forward results from each replica island. Control dependencies of individual
|
||||
// replicas are added to the single island if the single island does not emit
|
||||
// a result from the respective replica.
|
||||
// a result from the respective replica. Devices are remapped from aliased
|
||||
// devices to explicit devices, for `tf_device.launch` ops.
|
||||
//
|
||||
// For example, the following:
|
||||
//
|
||||
// %0:2 = tf_executor.island(%control) {
|
||||
// %1:4 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<i1>)
|
||||
// {n = 2 : i32, devices = ["/CPU:0", "/GPU:1"]} {
|
||||
// %2 = "tf.opA"(%ri) : (tensor<i1>) -> tensor<i1>
|
||||
// %3 = "tf.opB"(%2) : (tensor<i1>) -> tensor<i1>
|
||||
// tf_device.return %2, %3 : tensor<i1>, tensor<i1>
|
||||
// {n = 2 : i32,
|
||||
// devices = {DEVICE_ALIAS_0 = ["/DEVICE:0", "/DEVICE:1"],
|
||||
// DEVICE_ALIAS_1 = ["/DEVICE:2", "/DEVICE:3"]}} {
|
||||
// %a = "tf_device.launch"() ( {
|
||||
// %2 = "tf.opA"(%ri) : (tensor<i1>) -> tensor<i1>
|
||||
// tf_device.return %2 : tensor<i1>
|
||||
// }) {device = "DEVICE_ALIAS_0"} : () -> tensor<i1>
|
||||
// %b = "tf_device.launch"() ( {
|
||||
// %3 = "tf.opB"(%a) : (tensor<i1>) -> tensor<i1>
|
||||
// tf_device.return %3 : tensor<i1>
|
||||
// }) {device = "DEVICE_ALIAS_1"} : () -> tensor<i1>
|
||||
// tf_device.return %a, %b : tensor<i1>, tensor<i1>
|
||||
// }
|
||||
// tf_executor.yield %1#0 : tensor<i1>
|
||||
// }
|
||||
@ -141,14 +134,26 @@ llvm::SmallVector<tf_executor::IslandOp, 8> ExpandReplicateIntoReplicas(
|
||||
// gets lowered to:
|
||||
//
|
||||
// %0:3 = tf_executor.island(%control) {
|
||||
// %1 = "tf.opA"(%arg0) {device = "/CPU:0"} : (tensor<i1>) -> tensor<i1>
|
||||
// %2 = "tf.opB"(%1) {device = "/CPU:0"} : (tensor<i1>) -> tensor<i1>
|
||||
// tf_executor.yield %1, %2 : tensor<i1>, tensor<i1>
|
||||
// %a0 = "tf_device.launch"() ( {
|
||||
// %1 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
|
||||
// tf_device.return %1 : tensor<i1>
|
||||
// }) {device = "/DEVICE:0"} : () -> tensor<i1>
|
||||
// %b0 = "tf_device.launch"() ( {
|
||||
// %2 = "tf.opB"(%a0) : (tensor<i1>) -> tensor<i1>
|
||||
// tf_device.return %2 : tensor<i1>
|
||||
// }) {device = "/DEVICE:2"} : () -> tensor<i1>
|
||||
// tf_executor.yield %a0, %b0 : tensor<i1>, tensor<i1>
|
||||
// }
|
||||
// %3:3 = tf_executor.island(%control) {
|
||||
// %4 = "tf.opA"(%arg1) {device = "/GPU:1"} : (tensor<i1>) -> tensor<i1>
|
||||
// %5 = "tf.opB"(%4) {device = "/GPU:1"} : (tensor<i1>) -> tensor<i1>
|
||||
// tf_executor.yield %4, %5 : tensor<i1>, tensor<i1>
|
||||
// %a1 = "tf_device.launch"() ( {
|
||||
// %4 = "tf.opA"(%arg1) : (tensor<i1>) -> tensor<i1>
|
||||
// tf_device.return %4 : tensor<i1>
|
||||
// }) {device = "/DEVICE:1"} : () -> tensor<i1>
|
||||
// %b1 = "tf_device.launch"() ( {
|
||||
// %5 = "tf.opB"(%a1) : (tensor<i1>) -> tensor<i1>
|
||||
// tf_device.return %5 : tensor<i1>
|
||||
// }) {device = "/DEVICE:3"} : () -> tensor<i1>
|
||||
// tf_executor.yield %a1, %b1 : tensor<i1>, tensor<i1>
|
||||
// }
|
||||
// %6:2 = tf_executor.island(%3#2) {
|
||||
// tf_executor.yield %0#0 : tensor<i1>
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
|
||||
@ -365,9 +366,10 @@ LogicalResult FindResourceArgUseInfo(
|
||||
auto return_op = func_op.front().getTerminator();
|
||||
for (auto arg : func_op.getArguments()) {
|
||||
if (!getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) continue;
|
||||
auto& info = (*result)[arg.getArgNumber()];
|
||||
ResourceArgUseInfo info;
|
||||
info.used = false;
|
||||
info.updated = false;
|
||||
bool do_not_touch = false;
|
||||
for (auto user : arg.getUsers()) {
|
||||
if (user == return_op) continue;
|
||||
if (auto read = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
|
||||
@ -381,9 +383,16 @@ LogicalResult FindResourceArgUseInfo(
|
||||
info.data_type = assign.value().getType();
|
||||
continue;
|
||||
}
|
||||
user->emitError("Found unsupported operations on resource.");
|
||||
if (llvm::isa<TF::StackPushV2Op>(user) ||
|
||||
llvm::isa<TF::StackPopV2Op>(user)) {
|
||||
// Stacks will be handled by a separate pass.
|
||||
do_not_touch = true;
|
||||
break;
|
||||
}
|
||||
user->emitOpError("found unsupported operations on resource.");
|
||||
return failure();
|
||||
}
|
||||
if (!do_not_touch) (*result)[arg.getArgNumber()] = info;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
@ -395,14 +404,18 @@ LogicalResult FindResourceArgUseInfo(
|
||||
llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> MergeArgResourceUseInfo(
|
||||
const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos0,
|
||||
const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos1) {
|
||||
auto result = infos0;
|
||||
for (auto& entry : result) {
|
||||
if (entry.getSecond().used) continue;
|
||||
auto& info1_entry = *infos1.find(entry.getFirst());
|
||||
if (info1_entry.getSecond().used) {
|
||||
entry.getSecond().used = true;
|
||||
entry.getSecond().updated |= info1_entry.getSecond().updated;
|
||||
entry.getSecond().data_type = info1_entry.getSecond().data_type;
|
||||
llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> result;
|
||||
for (const auto& entry : infos0) {
|
||||
auto info1_it = infos1.find(entry.getFirst());
|
||||
// If the entry is missing in any input, we should not touch this entry.
|
||||
if (info1_it == infos1.end()) continue;
|
||||
auto& info = result[entry.getFirst()];
|
||||
info = entry.getSecond();
|
||||
if (info.updated) continue;
|
||||
if (info1_it->getSecond().used) {
|
||||
info.used = true;
|
||||
info.updated = info1_it->getSecond().updated;
|
||||
info.data_type = info1_it->getSecond().data_type;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
@ -576,16 +589,16 @@ LogicalResult HanldeWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
|
||||
for (auto arg : body.getArguments()) {
|
||||
if (!getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) continue;
|
||||
if (return_op->getOperand(arg.getArgNumber()) != arg) {
|
||||
return return_op->emitError(
|
||||
"Resource used in while loop is only supported when the resource "
|
||||
"input and output alias each other in the loop body.");
|
||||
return return_op->emitOpError(
|
||||
"resource used in while loop is only supported when the ")
|
||||
<< "resource input and output alias each other in the loop body.";
|
||||
}
|
||||
}
|
||||
// FindResourceArgUseInfo will check supported resource ops (read and assign),
|
||||
// but loop condition has additional requirement that it cannot write
|
||||
// resources.
|
||||
if (cond.walk([&](TF::AssignVariableOp assign) {
|
||||
assign.emitError("Found resource write in loop condition.");
|
||||
assign.emitOpError("found resource write in loop condition.");
|
||||
return WalkResult::interrupt();
|
||||
})
|
||||
.wasInterrupted()) {
|
||||
@ -668,7 +681,7 @@ LogicalResult HanldeWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
|
||||
}
|
||||
|
||||
// Lifts loads/stores from an IfOp's branches.
|
||||
LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
|
||||
LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
|
||||
FuncOp else_branch) {
|
||||
// Remove identity nodes to avoid aliasing.
|
||||
RemoveIdentity(&then_branch.front());
|
||||
@ -689,9 +702,8 @@ LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
|
||||
auto else_aliasing_arg = else_retval.dyn_cast<BlockArgument>();
|
||||
if (!then_aliasing_arg || !else_aliasing_arg ||
|
||||
then_aliasing_arg.getArgNumber() != else_aliasing_arg.getArgNumber()) {
|
||||
return if_op.emitOpError(
|
||||
"Unsupported tf.IfOp output: resource does not alias a single "
|
||||
"input.");
|
||||
return if_op.emitOpError("unsupported tf.IfOp output: ")
|
||||
<< "resource does not alias a single input.";
|
||||
}
|
||||
if_op.getResult(entry.index())
|
||||
.replaceAllUsesWith(
|
||||
@ -701,6 +713,7 @@ LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
|
||||
int64_t non_resource_results = 0;
|
||||
llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
|
||||
llvm::SmallVector<Attribute, 4> new_output_shapes;
|
||||
bool output_removed = false;
|
||||
for (auto result : if_op.getResults()) {
|
||||
if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
|
||||
old_to_new_output_indices.push_back(non_resource_results++);
|
||||
@ -713,6 +726,7 @@ LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
|
||||
old_to_new_output_indices.push_back(-1);
|
||||
then_branch.front().getTerminator()->eraseOperand(non_resource_results);
|
||||
else_branch.front().getTerminator()->eraseOperand(non_resource_results);
|
||||
output_removed = true;
|
||||
}
|
||||
|
||||
llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> then_use_info;
|
||||
@ -724,7 +738,7 @@ LogicalResult HanldeIfOP(TF::IfOp if_op, FuncOp then_branch,
|
||||
// A resource is considered used as long as it is used in either branch.
|
||||
auto resource_arg_uses =
|
||||
MergeArgResourceUseInfo(then_use_info, else_use_info);
|
||||
if (resource_arg_uses.empty()) return success();
|
||||
if (resource_arg_uses.empty() && !output_removed) return success();
|
||||
// Remove unused resources in functions.
|
||||
llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
|
||||
RemoveUnusedResourceArgumentsAndForwardedRetvals(
|
||||
@ -847,9 +861,8 @@ LogicalResult HandlePartitionedCallOpCallee(
|
||||
}
|
||||
auto aliasing_arg = retval.dyn_cast<BlockArgument>();
|
||||
if (!aliasing_arg) {
|
||||
return callee.emitOpError(
|
||||
"Unsupported function call: resource return value does not alias an "
|
||||
"input.");
|
||||
return callee.emitOpError("unsupported function call: ")
|
||||
<< "resource return value does not alias an input.";
|
||||
}
|
||||
result->old_outputs_aliasing_old_inputs[entry.index()] =
|
||||
aliasing_arg.getArgNumber();
|
||||
@ -982,6 +995,8 @@ LogicalResult HoistForFunctionalControlFlow(
|
||||
Block* block, ModuleOp module,
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallLiftingInfo>*
|
||||
lifted_partitioned_call_callees) {
|
||||
// Remove identity nodes to avoid aliasing.
|
||||
RemoveIdentity(block);
|
||||
for (Operation& op : llvm::make_early_inc_range(*block)) {
|
||||
if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
|
||||
auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body()));
|
||||
@ -1002,11 +1017,11 @@ LogicalResult HoistForFunctionalControlFlow(
|
||||
lifted_partitioned_call_callees);
|
||||
HoistForFunctionalControlFlow(&else_branch.front(), module,
|
||||
lifted_partitioned_call_callees);
|
||||
if (failed(HanldeIfOP(if_op, then_branch, else_branch))) return failure();
|
||||
if (failed(HandleIfOP(if_op, then_branch, else_branch))) return failure();
|
||||
} else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
|
||||
if (!call_op.f().isa<FlatSymbolRefAttr>()) {
|
||||
return call_op.emitError(
|
||||
"Resource lifting does not support call with nested references.");
|
||||
return call_op.emitOpError(
|
||||
"resource lifting does not support call with nested references.");
|
||||
}
|
||||
auto callee = llvm::cast<FuncOp>(
|
||||
module.lookupSymbol(call_op.f().getRootReference()));
|
||||
@ -1024,6 +1039,24 @@ LogicalResult HoistForFunctionalControlFlow(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove unused local variables.
|
||||
ForwardStoreToLoad(block);
|
||||
llvm::SmallVector<TF::MlirLocalVarOp, 8> local_vars;
|
||||
for (Operation& op : *block) {
|
||||
if (auto local_var = llvm::dyn_cast<TF::MlirLocalVarOp>(&op)) {
|
||||
local_vars.push_back(local_var);
|
||||
}
|
||||
}
|
||||
for (auto local_var : local_vars) {
|
||||
if (llvm::all_of(local_var.resource().getUsers(),
|
||||
[](const Operation* user) {
|
||||
return llvm::isa<TF::AssignVariableOp>(user);
|
||||
})) {
|
||||
for (auto user : local_var.resource().getUsers()) user->erase();
|
||||
local_var.erase();
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,742 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/None.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "mlir/IR/SymbolTable.h" // TF:llvm-project
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
|
||||
#include "mlir/IR/Types.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
// A pass that converts stack operations to tensor operations and read/assign
// ops on local variables. A later resource lifting pass can further remove the
// local variables.
//
// This pass requires that the full shape of the stack can be inferred: 1) the
// maximum size needs to be a constant and 2) a push op can be found with a
// known shape, and all push ops need to have the same shape.
//
// A stack creation op "tf.StackV2" will be turned into two zero-initialized
// variables, for the buffer and current size. Each push will be turned into
//   %old_val = "tf.ReadVariableOp"(%buffer)
//   %old_size = "tf.ReadVariableOp"(%size)
//   %offsets = "tf.ConcatV2"(%old_size, %other_dims_0s, %const0)
//   %new_val = "tf.XlaDynamicUpdateSlice"(%old_val, %push_val, %offsets)
//   "tf.AssignVariableOp"(%buffer, %new_val)
//   %new_size = "tf.AddV2"(%old_size, %const1)
//   "tf.AssignVariableOp"(%size, %new_size)
//
// and each pop will be turned into
//
//   %old_val = "tf.ReadVariableOp"(%buffer)
//   %old_size = "tf.ReadVariableOp"(%size)
//   %new_size = "tf.Sub"(%old_size, %const1)
//   %offsets = "tf.ConcatV2"(%old_size, %other_dims_0s, %const0)
//   %slice = "tf.Slice"(%old_val, %offsets, %slice_size_const)
//   %pop_result = "tf.Reshape"(%slice, %elem_size_const)
//   "tf.AssignVariableOp"(%size, %new_size)
//
// The pass also works across control flow and functional calls.
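//
// As a rough illustration (an approximate sketch, not the exact IR this pass
// emits), a "tf.StackV2" with max_size = 10 and element type tensor<2xf32>
// becomes two zero-initialized local variables:
//   %buffer = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<10x2xf32>>>
//   %size = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<1xi32>>>
//   "tf.AssignVariableOp"(%size, %zero_size)      // size starts at [0]
//   "tf.AssignVariableOp"(%buffer, %zero_buffer)  // broadcast zeros of shape 10x2
// where %zero_size and %zero_buffer are placeholder names for the generated
// constants.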
|
||||
struct StackOpsDecompositionPass
|
||||
: public ModulePass<StackOpsDecompositionPass> {
|
||||
void runOnModule() override;
|
||||
};
|
||||
|
||||
// Creates a ReadVariableOp on a local variable.
|
||||
Value ReadLocalVariable(Value local_var, OpBuilder builder, Location loc) {
|
||||
return builder
|
||||
.create<TF::ReadVariableOp>(
|
||||
loc,
|
||||
ArrayRef<Type>{getElementTypeOrSelf(local_var.getType())
|
||||
.cast<TF::ResourceType>()
|
||||
.getSubtypes()[0]},
|
||||
ArrayRef<Value>{local_var}, ArrayRef<NamedAttribute>{})
|
||||
.value();
|
||||
}
|
||||
|
||||
// Creates an AssignVariableOp on a local variable.
|
||||
TF::AssignVariableOp WriteLocalVariable(Value local_var, Value value,
|
||||
OpBuilder builder, Location loc) {
|
||||
return builder.create<TF::AssignVariableOp>(loc, ArrayRef<Type>{},
|
||||
ArrayRef<Value>{local_var, value},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
}
|
||||
|
||||
// Creates an i32 scalar tf.Const.
|
||||
TF::ConstOp CreateScalarConst(int value, OpBuilder builder, Location loc) {
|
||||
tensorflow::Tensor scalar_tensor(tensorflow::DT_INT32, {});
|
||||
scalar_tensor.scalar<tensorflow::int32>()() = value;
|
||||
return builder.create<TF::ConstOp>(
|
||||
loc, tensorflow::ConvertTensor(scalar_tensor, &builder).ValueOrDie());
|
||||
}
|
||||
|
||||
// Creates an i32 vector tf.Const.
|
||||
TF::ConstOp GetR1Const(ArrayRef<int64_t> r1, OpBuilder builder, Location loc) {
|
||||
tensorflow::Tensor shape_tensor(tensorflow::DT_INT32,
|
||||
{static_cast<int64_t>(r1.size())});
|
||||
for (int i = 0; i < r1.size(); ++i) {
|
||||
shape_tensor.vec<tensorflow::int32>()(i) = r1[i];
|
||||
}
|
||||
return builder.create<TF::ConstOp>(
|
||||
loc, tensorflow::ConvertTensor(shape_tensor, &builder).ValueOrDie());
|
||||
}
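// For example, GetR1Const({5, 6}, builder, loc) produces a "tf.Const" holding
// dense<[5, 6]> : tensor<2xi32>; the int64_t inputs are narrowed to i32.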
|
||||
|
||||
// Creates a rank-1 op that represents the offsets of the stack element in the
|
||||
// stack buffer.
|
||||
Value GetIndicesForStackElement(Value index, Value stack_value,
|
||||
OpBuilder builder, Location loc) {
|
||||
auto stack_type = stack_value.getType().cast<RankedTensorType>();
|
||||
if (stack_type.getShape().size() == 1) return index;
|
||||
llvm::SmallVector<int64_t, 8> zeros(stack_type.getShape().size() - 1, 0);
|
||||
auto zeros_tensor = GetR1Const(zeros, builder, loc);
|
||||
return builder.create<TF::ConcatV2Op>(
|
||||
loc,
|
||||
ArrayRef<Type>{RankedTensorType::get(
|
||||
{static_cast<int64_t>(stack_type.getShape().size())},
|
||||
getElementTypeOrSelf(index.getType()))},
|
||||
ArrayRef<Value>{index, zeros_tensor, CreateScalarConst(0, builder, loc)},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
}
|
||||
|
||||
// Returns the type of the local variable for the stack size. It is a
|
||||
// tensor<1xi32>, and we use R1 instead of a scalar because it is easier to
|
||||
// concat it with other offsets.
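// Concretely, the returned type is tensor<!tf.resource<tensor<1xi32>>>.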
|
||||
Type GetSizeVarType(OpBuilder builder) {
|
||||
auto size_type = RankedTensorType::get({1}, builder.getIntegerType(32));
|
||||
return RankedTensorType::get(
|
||||
{}, TF::ResourceType::get(ArrayRef<TensorType>{size_type},
|
||||
builder.getContext()));
|
||||
}
|
||||
|
||||
// Creates the buffer and size local variables for a stack.
|
||||
std::pair<Value, Value> CreateVariablesForStack(TensorType stack_tensor_type,
|
||||
TF::StackV2Op stack) {
|
||||
OpBuilder builder(stack);
|
||||
auto size_var_type = GetSizeVarType(builder);
|
||||
auto var_type = RankedTensorType::get(
|
||||
{}, TF::ResourceType::get(ArrayRef<TensorType>{stack_tensor_type},
|
||||
stack.getContext()));
|
||||
auto local_var = builder.create<TF::MlirLocalVarOp>(
|
||||
stack.getLoc(), ArrayRef<Type>{var_type}, ArrayRef<Value>{},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
auto local_size_var = builder.create<TF::MlirLocalVarOp>(
|
||||
stack.getLoc(), ArrayRef<Type>{size_var_type}, ArrayRef<Value>{},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
|
||||
// Zero-initialize the local vars.
|
||||
WriteLocalVariable(local_size_var, GetR1Const({0LL}, builder, stack.getLoc()),
|
||||
builder, stack.getLoc());
|
||||
auto zero = CreateScalarConst(0, builder, stack.getLoc()).output();
|
||||
if (getElementTypeOrSelf(zero.getType()) !=
|
||||
stack_tensor_type.getElementType()) {
|
||||
zero = builder.create<TF::CastOp>(
|
||||
stack.getLoc(),
|
||||
ArrayRef<Type>{
|
||||
RankedTensorType::get({}, stack_tensor_type.getElementType())},
|
||||
ArrayRef<Value>{zero}, ArrayRef<NamedAttribute>{});
|
||||
}
|
||||
auto broadcast = builder.create<TF::BroadcastToOp>(
|
||||
stack.getLoc(), ArrayRef<Type>{stack_tensor_type},
|
||||
ArrayRef<Value>{zero, GetR1Const(stack_tensor_type.getShape(), builder,
|
||||
stack.getLoc())},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
WriteLocalVariable(local_var, broadcast, builder, stack.getLoc());
|
||||
return {local_var, local_size_var};
|
||||
}
|
||||
|
||||
// Tries to infer the stack element type with full shape based on its uses.
|
||||
llvm::Optional<RankedTensorType> GetStackElementType(Value stack,
|
||||
ModuleOp module) {
|
||||
for (auto& use : stack.getUses()) {
|
||||
if (auto push = llvm::dyn_cast<TF::StackPushV2Op>(use.getOwner())) {
|
||||
auto elem_type = push.elem().getType().dyn_cast<RankedTensorType>();
|
||||
if (elem_type && elem_type.hasStaticShape()) {
|
||||
return elem_type;
|
||||
}
|
||||
} else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(use.getOwner())) {
|
||||
auto body = module.lookupSymbol<FuncOp>(while_op.body());
|
||||
assert(body);
|
||||
auto type_from_body =
|
||||
GetStackElementType(body.getArgument(use.getOperandNumber()), module);
|
||||
if (type_from_body.hasValue()) return type_from_body;
|
||||
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(use.getOwner())) {
|
||||
auto then_branch = module.lookupSymbol<FuncOp>(if_op.then_branch());
|
||||
auto else_branch = module.lookupSymbol<FuncOp>(if_op.else_branch());
|
||||
assert(then_branch && else_branch);
|
||||
auto type_from_then = GetStackElementType(
|
||||
then_branch.getArgument(use.getOperandNumber() - 1), module);
|
||||
if (type_from_then.hasValue()) return type_from_then;
|
||||
auto type_from_else = GetStackElementType(
|
||||
else_branch.getArgument(use.getOperandNumber() - 1), module);
|
||||
if (type_from_else.hasValue()) return type_from_else;
|
||||
} else if (auto pcall =
|
||||
llvm::dyn_cast<TF::PartitionedCallOp>(use.getOwner())) {
|
||||
if (!pcall.f().isa<FlatSymbolRefAttr>()) continue;
|
||||
auto callee = module.lookupSymbol<FuncOp>(pcall.f().getRootReference());
|
||||
assert(callee);
|
||||
auto type_from_callee = GetStackElementType(
|
||||
callee.getArgument(use.getOperandNumber()), module);
|
||||
if (type_from_callee.hasValue()) return type_from_callee;
|
||||
} else if (auto spcall = llvm::dyn_cast<TF::StatefulPartitionedCallOp>(
|
||||
use.getOwner())) {
|
||||
auto callee = module.lookupSymbol<FuncOp>(spcall.f());
|
||||
assert(callee);
|
||||
auto type_from_callee = GetStackElementType(
|
||||
callee.getArgument(use.getOperandNumber()), module);
|
||||
if (type_from_callee.hasValue()) return type_from_callee;
|
||||
} else if (llvm::isa<TF::IdentityOp>(use.getOwner()) ||
|
||||
llvm::isa<TF::IdentityNOp>(use.getOwner())) {
|
||||
auto type_from_alias = GetStackElementType(
|
||||
use.getOwner()->getResult(use.getOperandNumber()), module);
|
||||
if (type_from_alias.hasValue()) return type_from_alias;
|
||||
}
|
||||
}
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
// Returns the aliasing argument number of a function return value if it simply
|
||||
// forwards the argument. Otherwise, returns -1.
|
||||
int64_t FindAliasedInput(FuncOp func, int64_t return_index) {
|
||||
Value return_val = func.front().getTerminator()->getOperand(return_index);
|
||||
auto maybe_arg = return_val.dyn_cast<BlockArgument>();
|
||||
if (!maybe_arg) return -1;
|
||||
return maybe_arg.getArgNumber();
|
||||
}
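// For example, if a function's terminator is "return %arg1, %sum", then
// FindAliasedInput(func, 0) returns 1 (the first return value just forwards
// argument 1), while FindAliasedInput(func, 1) returns -1.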
|
||||
|
||||
// Changes the function signature that has stacks in the arguments. A stack
|
||||
// argument will be turned into a variable type if arg_to_stack_type returns
|
||||
// such a type, and a new argument will be added to the end of the argument
|
||||
// list for the size variable.
|
||||
//
|
||||
// If stack_var_to_size_var is not nullptr, it will be used to store the
|
||||
// mapping from the stack-variable argument to the size-variable argument.
|
||||
//
|
||||
// If handle_new_size_vars is provided, it will be invoked on the list of new
|
||||
// size variables before finally changing the function type.
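//
// Illustrative example (types are approximate): a function with signature
//   (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
// whose first argument is a stack of tensor<3xf32> with max size 10 would be
// rewritten to
//   (tensor<!tf.resource<tensor<10x3xf32>>>, tensor<f32>,
//    tensor<!tf.resource<tensor<1xi32>>>) -> tensor<f32>
// with the trailing argument holding the stack's current size.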
|
||||
void ModifyFunctionSignature(
|
||||
FuncOp func, llvm::SmallDenseMap<Value, Value>* stack_var_to_size_var,
|
||||
llvm::function_ref<llvm::Optional<Type>(int64_t)> arg_to_stack_type,
|
||||
llvm::function_ref<void(ArrayRef<BlockArgument>)> handle_new_size_vars =
|
||||
nullptr) {
|
||||
auto new_input_types = llvm::to_vector<8>(func.getType().getInputs());
|
||||
auto size_var_type = GetSizeVarType(OpBuilder(func));
|
||||
int64_t original_arg_count = new_input_types.size();
|
||||
for (int64_t i = 0; i < original_arg_count; ++i) {
|
||||
auto stack_type = arg_to_stack_type(i);
|
||||
if (!stack_type.hasValue()) continue;
|
||||
func.getArgument(i).setType(*stack_type);
|
||||
new_input_types[i] = *stack_type;
|
||||
auto size_arg = func.front().addArgument(size_var_type);
|
||||
new_input_types.push_back(size_arg.getType());
|
||||
if (stack_var_to_size_var) {
|
||||
(*stack_var_to_size_var)[func.getArgument(i)] = size_arg;
|
||||
}
|
||||
}
|
||||
if (handle_new_size_vars) {
|
||||
handle_new_size_vars(func.getArguments().drop_front(original_arg_count));
|
||||
}
|
||||
func.setType(FunctionType::get(
|
||||
new_input_types,
|
||||
llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()),
|
||||
func.getContext()));
|
||||
}
|
||||
|
||||
// Contains cached information for decomposed callee functions for (stateful)
|
||||
// partitioned call ops.
|
||||
struct PartitionedCallStackOpsInfo {
|
||||
bool signature_change;
|
||||
FuncOp decomposed_callee;
|
||||
llvm::SmallDenseMap<int64_t, int64_t> stack_var_arg_to_size_arg;
|
||||
};
|
||||
|
||||
LogicalResult DecomposeStackOpsInternal(
|
||||
Block*, ModuleOp, llvm::SmallDenseMap<Value, Value>*,
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallStackOpsInfo>*);
|
||||
|
||||
// Handles stack usage by a tf.While. It converts the body and conditional
// function signatures, and performs stack ops decomposition on them.
|
||||
LogicalResult HandleWhileOp(
|
||||
TF::WhileOp while_op, ModuleOp module,
|
||||
const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallStackOpsInfo>*
|
||||
decomposed_partitioned_call_callees) {
|
||||
auto body = module.lookupSymbol<FuncOp>(while_op.body());
|
||||
llvm::SmallDenseMap<Value, Value> body_map;
|
||||
auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
|
||||
auto it = data_var_to_size_var.find(while_op.getOperand(index));
|
||||
if (it == data_var_to_size_var.end()) return llvm::None;
|
||||
return it->getFirst().getType();
|
||||
};
|
||||
auto add_size_vars_to_return = [&](ArrayRef<BlockArgument> new_args) {
|
||||
if (new_args.empty()) return;
|
||||
auto body_ret = body.front().getTerminator();
|
||||
auto new_body_returns = llvm::to_vector<8>(body_ret->getOperands());
|
||||
for (auto arg : new_args) new_body_returns.push_back(arg);
|
||||
OpBuilder(body_ret).create<ReturnOp>(body_ret->getLoc(), new_body_returns);
|
||||
body_ret->erase();
|
||||
};
|
||||
// Handle body.
|
||||
ModifyFunctionSignature(body, &body_map, find_arg_stack_type,
|
||||
add_size_vars_to_return);
|
||||
const bool signature_change = !body_map.empty();
|
||||
if (failed(DecomposeStackOpsInternal(&body.front(), module, &body_map,
|
||||
decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
// Cond should not change stacks in the arguments, so use an empty map.
|
||||
auto cond = module.lookupSymbol<FuncOp>(while_op.cond());
|
||||
ModifyFunctionSignature(cond, nullptr, find_arg_stack_type);
|
||||
llvm::SmallDenseMap<Value, Value> empty_map;
|
||||
if (failed(DecomposeStackOpsInternal(&cond.front(), module, &empty_map,
|
||||
decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
if (!signature_change) return success();
|
||||
// Create the new while op.
|
||||
auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
|
||||
auto new_output_shapes =
|
||||
llvm::to_vector<8>(while_op.output_shapes().getValue());
|
||||
OpBuilder builder(while_op);
|
||||
assert(while_op.getNumOperands() == while_op.getNumResults());
|
||||
for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
|
||||
auto it = data_var_to_size_var.find(while_op.getOperand(i));
|
||||
if (it == data_var_to_size_var.end()) continue;
|
||||
new_while_operands.push_back(it->getSecond());
|
||||
if (!new_output_shapes.empty()) {
|
||||
// Size is a scalar shape.
|
||||
tensorflow::TensorShapeProto shape_proto;
|
||||
new_output_shapes.push_back(builder.getStringAttr(
|
||||
tensorflow::mangling_util::MangleShape(shape_proto)));
|
||||
}
|
||||
}
|
||||
auto new_while =
|
||||
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
|
||||
new_while_operands, while_op.getAttrs());
|
||||
new_while.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes));
|
||||
for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
|
||||
if (!getElementTypeOrSelf(while_op.getOperand(i).getType())
|
||||
.isa<TF::ResourceType>()) {
|
||||
continue;
|
||||
}
|
||||
int64_t aliased_input = FindAliasedInput(body, i);
|
||||
if (aliased_input == i) {
|
||||
// Replace aliased stack output uses with input.
|
||||
while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
|
||||
}
|
||||
}
|
||||
while_op.replaceAllUsesWith(
|
||||
new_while.getResults().take_front(while_op.getNumResults()));
|
||||
while_op.erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Handles stack usage by a tf.If. It converts the branch function
// signatures, and performs stack ops decomposition on them.
|
||||
LogicalResult HandleIfOp(
|
||||
TF::IfOp if_op, ModuleOp module,
|
||||
const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallStackOpsInfo>*
|
||||
decomposed_partitioned_call_callees) {
|
||||
auto then_branch = module.lookupSymbol<FuncOp>(if_op.then_branch());
|
||||
auto else_branch = module.lookupSymbol<FuncOp>(if_op.else_branch());
|
||||
llvm::SmallDenseMap<Value, Value> then_map;
|
||||
llvm::SmallDenseMap<Value, Value> else_map;
|
||||
|
||||
auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
|
||||
auto it = data_var_to_size_var.find(if_op.getOperand(index + 1));
|
||||
if (it == data_var_to_size_var.end()) return llvm::None;
|
||||
return it->getFirst().getType();
|
||||
};
|
||||
ModifyFunctionSignature(then_branch, &then_map, find_arg_stack_type);
|
||||
ModifyFunctionSignature(else_branch, &else_map, find_arg_stack_type);
|
||||
const bool signature_change = !then_map.empty() || !else_map.empty();
|
||||
if (failed(DecomposeStackOpsInternal(&then_branch.front(), module, &then_map,
|
||||
decomposed_partitioned_call_callees)) ||
|
||||
failed(DecomposeStackOpsInternal(&else_branch.front(), module, &else_map,
|
||||
decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
if (!signature_change) return success();
|
||||
auto new_if_operands = llvm::to_vector<8>(if_op.getOperands());
|
||||
for (auto operand : if_op.getOperands()) {
|
||||
auto it = data_var_to_size_var.find(operand);
|
||||
if (it == data_var_to_size_var.end()) continue;
|
||||
new_if_operands.push_back(it->getSecond());
|
||||
}
|
||||
auto new_if = OpBuilder(if_op).create<TF::IfOp>(
|
||||
if_op.getLoc(), then_branch.getType().getResults(), new_if_operands,
|
||||
if_op.getAttrs());
|
||||
for (auto result : if_op.getResults()) {
|
||||
if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
|
||||
continue;
|
||||
}
|
||||
int64_t then_aliased_input =
|
||||
FindAliasedInput(then_branch, result.getResultNumber());
|
||||
int64_t else_aliased_input =
|
||||
FindAliasedInput(else_branch, result.getResultNumber());
|
||||
if (then_aliased_input >= 0 && then_aliased_input == else_aliased_input) {
|
||||
// Replace aliased stack output uses with input.
|
||||
result.replaceAllUsesWith(if_op.getOperand(then_aliased_input + 1));
|
||||
}
|
||||
}
|
||||
if_op.replaceAllUsesWith(new_if);
|
||||
if_op.erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Handles stack usage by a tf.StatefulPartitionedCall or a tf.PartitionedCall.
// It first checks whether the callee was previously handled and reuses that
// result if so. Otherwise, it clones and converts the callee function, and
// performs stack ops decomposition on it.
|
||||
template <typename CallOp>
|
||||
LogicalResult HandlePartitionedCallOp(
|
||||
CallOp call, FuncOp callee, ModuleOp module,
|
||||
const llvm::SmallDenseMap<Value, Value>& data_var_to_size_var,
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallStackOpsInfo>*
|
||||
decomposed_partitioned_call_callees) {
|
||||
auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
|
||||
callee, PartitionedCallStackOpsInfo());
|
||||
auto& info = emplace_res.first->getSecond();
|
||||
// Recreate the call op with info.
|
||||
auto recreate_caller = [&] {
|
||||
auto new_operands = llvm::to_vector<8>(call.getOperands());
|
||||
for (int64_t i = 0; i < call.getNumOperands(); ++i) {
|
||||
auto arg_it = info.stack_var_arg_to_size_arg.find(i);
|
||||
if (arg_it == info.stack_var_arg_to_size_arg.end()) continue;
|
||||
auto it = data_var_to_size_var.find(call.getOperand(i));
|
||||
if (it == data_var_to_size_var.end()) {
|
||||
call.emitOpError("Unknown stack.");
|
||||
return failure();
|
||||
}
|
||||
assert(arg_it->second == new_operands.size());
|
||||
new_operands.push_back(it->getSecond());
|
||||
}
|
||||
OpBuilder builder(call);
|
||||
auto new_call = builder.create<CallOp>(
|
||||
call.getLoc(), info.decomposed_callee.getType().getResults(),
|
||||
new_operands, call.getAttrs());
|
||||
new_call.setAttr(
|
||||
"f", builder.getSymbolRefAttr(
|
||||
const_cast<FuncOp&>(info.decomposed_callee).getName()));
|
||||
for (int64_t i = 0; i < call.getNumResults(); ++i) {
|
||||
auto result = call.getResult(i);
|
||||
if (!getElementTypeOrSelf(result.getType())
|
||||
.template isa<TF::ResourceType>()) {
|
||||
continue;
|
||||
}
|
||||
int64_t aliased_input = FindAliasedInput(info.decomposed_callee, i);
|
||||
if (aliased_input >= 0) {
|
||||
// Replace aliased stack output uses with input.
|
||||
result.replaceAllUsesWith(call.getOperand(aliased_input));
|
||||
}
|
||||
}
|
||||
call.replaceAllUsesWith(new_call);
|
||||
call.erase();
|
||||
return success();
|
||||
};
|
||||
if (!emplace_res.second) {
|
||||
// This callee was handled before.
|
||||
if (!info.signature_change) return success();
|
||||
return recreate_caller();
|
||||
}
|
||||
llvm::SmallDenseMap<Value, Value> callee_map;
|
||||
auto callee_clone = callee.clone();
|
||||
auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional<Type> {
|
||||
auto it = data_var_to_size_var.find(call.getOperand(index));
|
||||
if (it == data_var_to_size_var.end()) return llvm::None;
|
||||
return it->getFirst().getType();
|
||||
};
|
||||
ModifyFunctionSignature(callee_clone, &callee_map, find_arg_stack_type);
|
||||
if (callee_map.empty()) {
|
||||
// Signature is not modified. We do not need the clone.
|
||||
info.signature_change = false;
|
||||
callee_clone.erase();
|
||||
} else {
|
||||
info.signature_change = true;
|
||||
info.decomposed_callee = callee_clone;
|
||||
for (auto& entry : callee_map) {
|
||||
info.stack_var_arg_to_size_arg
|
||||
[entry.getFirst().cast<BlockArgument>().getArgNumber()] =
|
||||
entry.getSecond().cast<BlockArgument>().getArgNumber();
|
||||
}
|
||||
// Add the clone with a new name.
|
||||
auto name_base = llvm::join(
|
||||
std::vector<std::string>{callee.getName().str(), "stack_decomposed"},
|
||||
"_");
|
||||
auto name = name_base;
|
||||
{
|
||||
int64_t counter = 0;
|
||||
while (module.lookupSymbol(name)) {
|
||||
name = llvm::formatv("{0}_{1}", name_base, counter++).str();
|
||||
}
|
||||
}
|
||||
callee_clone.setName(name);
|
||||
SymbolTable(module).insert(callee_clone);
|
||||
callee = callee_clone;
|
||||
}
|
||||
if (failed(DecomposeStackOpsInternal(&callee.front(), module, &callee_map,
|
||||
decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
if (info.signature_change) return recreate_caller();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult HandleStackV2Op(
|
||||
TF::StackV2Op stack, ModuleOp module,
|
||||
llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
|
||||
// Create a buffer variable and a size variable to replace the stack.
|
||||
auto elem_type = GetStackElementType(stack.handle(), module);
|
||||
if (!elem_type.hasValue()) {
|
||||
return stack.emitOpError("cannot infer element shape of stack.");
|
||||
}
|
||||
auto size_op = stack.max_size().getDefiningOp();
|
||||
if (!size_op || !llvm::isa<TF::ConstOp>(size_op)) {
|
||||
return stack.emitOpError("max size of stack is not a constant.");
|
||||
}
|
||||
int64_t max_size =
|
||||
(*llvm::cast<TF::ConstOp>(size_op).value().getValues<APInt>().begin())
|
||||
.getSExtValue();
|
||||
llvm::SmallVector<int64_t, 8> stack_shape;
|
||||
stack_shape.push_back(max_size);
|
||||
for (int64_t dim : elem_type->getShape()) stack_shape.push_back(dim);
|
||||
auto stack_tensor_type =
|
||||
RankedTensorType::get(stack_shape, elem_type->getElementType());
|
||||
Value local_var;
|
||||
Value local_size_var;
|
||||
std::tie(local_var, local_size_var) =
|
||||
CreateVariablesForStack(stack_tensor_type, stack);
|
||||
stack.replaceAllUsesWith(local_var);
|
||||
(*data_var_to_size_var)[local_var] = local_size_var;
|
||||
stack.erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult HandleStackPushV2Op(
|
||||
TF::StackPushV2Op push,
|
||||
llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
|
||||
auto it = data_var_to_size_var->find(push.handle());
|
||||
if (it == data_var_to_size_var->end()) {
|
||||
return push.emitOpError("unknown stack.");
|
||||
}
|
||||
// The push output simply forwards the input element.
|
||||
push.replaceAllUsesWith(push.elem());
|
||||
OpBuilder builder(push);
|
||||
// Read the current buffer and size.
|
||||
auto stack_val = ReadLocalVariable(push.handle(), builder, push.getLoc());
|
||||
auto index = ReadLocalVariable(it->getSecond(), builder, push.getLoc());
|
||||
auto stack_buffer_type = stack_val.getType().cast<RankedTensorType>();
|
||||
auto slice_shape = llvm::to_vector<8>(stack_buffer_type.getShape());
|
||||
slice_shape[0] = 1;
|
||||
// Calculate the updated buffer.
|
||||
auto update_slice = builder.create<TF::ReshapeOp>(
|
||||
push.getLoc(),
|
||||
ArrayRef<Type>{RankedTensorType::get(slice_shape,
|
||||
stack_buffer_type.getElementType())},
|
||||
ArrayRef<Value>{push.elem(),
|
||||
GetR1Const(slice_shape, builder, push.getLoc())},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
stack_val =
|
||||
builder
|
||||
.create<TF::XlaDynamicUpdateSliceOp>(
|
||||
push.getLoc(), ArrayRef<Type>{stack_val.getType()},
|
||||
ArrayRef<Value>{stack_val, update_slice,
|
||||
GetIndicesForStackElement(
|
||||
index, stack_val, builder, push.getLoc())},
|
||||
ArrayRef<NamedAttribute>{})
|
||||
.output();
|
||||
// Assign the new buffer and size.
|
||||
WriteLocalVariable(push.handle(), stack_val, builder, push.getLoc());
|
||||
index = builder.create<TF::AddV2Op>(
|
||||
push.getLoc(), ArrayRef<Type>{index.getType()},
|
||||
ArrayRef<Value>{index, GetR1Const({1}, builder, push.getLoc())},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
WriteLocalVariable(it->getSecond(), index, builder, push.getLoc());
|
||||
push.erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult HandleStackPopV2Op(
|
||||
TF::StackPopV2Op pop,
|
||||
llvm::SmallDenseMap<Value, Value>* data_var_to_size_var) {
|
||||
auto it = data_var_to_size_var->find(pop.handle());
|
||||
if (it == data_var_to_size_var->end()) {
|
||||
return pop.emitOpError("unknown stack.");
|
||||
}
|
||||
OpBuilder builder(pop);
|
||||
// Read the current buffer and size.
|
||||
auto stack_val = ReadLocalVariable(pop.handle(), builder, pop.getLoc());
|
||||
auto size = ReadLocalVariable(it->getSecond(), builder, pop.getLoc());
|
||||
auto new_size = builder.create<TF::SubOp>(
|
||||
pop.getLoc(), ArrayRef<Type>{size.getType()},
|
||||
ArrayRef<Value>{size, GetR1Const({1}, builder, pop.getLoc())},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
auto stack_val_type = stack_val.getType().cast<RankedTensorType>();
|
||||
auto elem_type = RankedTensorType::get(stack_val_type.getShape().drop_front(),
|
||||
stack_val_type.getElementType());
|
||||
// Slice the buffer to get the element.
|
||||
llvm::SmallVector<int64_t, 8> slice_size;
|
||||
slice_size.push_back(1);
|
||||
for (int64_t dim : elem_type.getShape()) slice_size.push_back(dim);
|
||||
auto size_const = GetR1Const(slice_size, builder, pop.getLoc());
|
||||
auto slice_type =
|
||||
RankedTensorType::get(slice_size, stack_val_type.getElementType());
|
||||
auto slice = builder.create<TF::SliceOp>(
|
||||
pop.getLoc(), ArrayRef<Type>{slice_type},
|
||||
ArrayRef<Value>{
|
||||
stack_val,
|
||||
GetIndicesForStackElement(new_size, stack_val, builder, pop.getLoc()),
|
||||
size_const},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
auto pop_val = builder.create<TF::ReshapeOp>(
|
||||
pop.getLoc(), ArrayRef<Type>{elem_type},
|
||||
ArrayRef<Value>{slice,
|
||||
GetR1Const(elem_type.getShape(), builder, pop.getLoc())},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
pop.replaceAllUsesWith(pop_val.output());
|
||||
// Update the size.
|
||||
WriteLocalVariable(it->getSecond(), new_size, builder, pop.getLoc());
|
||||
pop.erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Decomposes stack ops on a region and recursively decomposes called functions.
|
||||
// data_var_to_size_var: a mapping from stacks' buffer local variables to size
|
||||
// local variables.
|
||||
// decomposed_partitioned_call_callees: cache for partitioned call ops' callee
|
||||
// function handling.
|
||||
LogicalResult DecomposeStackOpsInternal(
|
||||
Block* block, ModuleOp module,
|
||||
llvm::SmallDenseMap<Value, Value>* data_var_to_size_var,
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallStackOpsInfo>*
|
||||
decomposed_partitioned_call_callees) {
|
||||
for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
|
||||
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
|
||||
// Removes identity nodes in the block. The device computation does not
|
||||
// need such nodes to carry information.
|
||||
op.replaceAllUsesWith(op.getOperands());
|
||||
op.erase();
|
||||
} else if (auto stack = llvm::dyn_cast<TF::StackV2Op>(&op)) {
|
||||
if (failed(HandleStackV2Op(stack, module, data_var_to_size_var))) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto push = llvm::dyn_cast<TF::StackPushV2Op>(&op)) {
|
||||
if (failed(HandleStackPushV2Op(push, data_var_to_size_var))) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto pop = llvm::dyn_cast<TF::StackPopV2Op>(&op)) {
|
||||
if (failed(HandleStackPopV2Op(pop, data_var_to_size_var))) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto close = llvm::dyn_cast<TF::StackCloseV2Op>(&op)) {
|
||||
data_var_to_size_var->erase(close.handle());
|
||||
close.erase();
|
||||
} else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
|
||||
if (failed(HandleWhileOp(while_op, module, *data_var_to_size_var,
|
||||
decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
|
||||
if (failed(HandleIfOp(if_op, module, *data_var_to_size_var,
|
||||
decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
|
||||
if (!pcall.f().isa<FlatSymbolRefAttr>()) {
|
||||
return pcall.emitOpError(
|
||||
"Stack decomposition does not support call with nested "
|
||||
"references.");
|
||||
}
|
||||
if (failed(HandlePartitionedCallOp(
|
||||
pcall, module.lookupSymbol<FuncOp>(pcall.f().getRootReference()),
|
||||
module, *data_var_to_size_var,
|
||||
decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto spcall =
|
||||
llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
|
||||
if (failed(HandlePartitionedCallOp(
|
||||
spcall, module.lookupSymbol<FuncOp>(spcall.f()), module,
|
||||
*data_var_to_size_var, decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult DecomposeStackOps(Block* block, ModuleOp module) {
|
||||
llvm::SmallDenseMap<Value, Value> data_var_to_size_var;
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallStackOpsInfo>
|
||||
decomposed_partitioned_call_callees;
|
||||
return DecomposeStackOpsInternal(block, module, &data_var_to_size_var,
|
||||
&decomposed_partitioned_call_callees);
|
||||
}
|
||||
|
||||
void StackOpsDecompositionPass::runOnModule() {
|
||||
auto module = getModule();
|
||||
auto main = module.lookupSymbol<FuncOp>("main");
|
||||
if (!main) return;
|
||||
if (failed(DecomposeStackOps(&main.front(), module))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
static PassRegistration<StackOpsDecompositionPass> pass(
|
||||
"tf-stack-ops-decomposition",
|
||||
"Decompose stack operations into local variable operations. Needs static "
|
||||
"shapes.");
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace TF {
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateStackOpsDecompositionPass() {
|
||||
return std::make_unique<StackOpsDecompositionPass>();
|
||||
}
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "mlir/Support/STLExtras.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
@ -91,14 +92,14 @@ bool IsSupportedInputOp(Operation* op) {
|
||||
}
|
||||
|
||||
// Builds a TPUGetLayoutOp with the given compile op and input index.
|
||||
TF::TPUGetLayoutOp BuildGetLayout(Operation* compile, int64_t index,
|
||||
OpBuilder* builder) {
|
||||
builder->setInsertionPointAfter(compile);
|
||||
TF::TPUGetLayoutOp BuildGetLayout(tf_device::LaunchOp compile_launch,
|
||||
int64_t index, OpBuilder* builder) {
|
||||
builder->setInsertionPointAfter(compile_launch);
|
||||
return builder->create<TF::TPUGetLayoutOp>(
|
||||
compile->getLoc(),
|
||||
compile_launch.getLoc(),
|
||||
llvm::ArrayRef<Type>{
|
||||
RankedTensorType::get({-1}, builder->getIntegerType(64))},
|
||||
llvm::ArrayRef<Value>{compile->getResult(1)},
|
||||
llvm::ArrayRef<Value>{compile_launch.getResult(1)},
|
||||
llvm::ArrayRef<NamedAttribute>{
|
||||
builder->getNamedAttr("index", builder->getI64IntegerAttr(index)),
|
||||
builder->getNamedAttr("is_output", builder->getBoolAttr(false))});
|
||||
@ -109,11 +110,12 @@ TF::TPUGetLayoutOp BuildGetLayout(Operation* compile, int64_t index,
|
||||
// ops after both get_layout and input, so we use the walk order to find which
|
||||
// one comes later.
|
||||
TF::TPUCopyWithLayoutOp BuildCopyWithLayout(
|
||||
TF::TPUExecuteOp execute, Operation* compile, TF::TPUGetLayoutOp get_layout,
|
||||
Value input, const llvm::SmallDenseMap<Operation*, int64_t>& walk_order,
|
||||
TF::TPUExecuteOp execute, tf_device::LaunchOp compile_launch,
|
||||
TF::TPUGetLayoutOp get_layout, Value input,
|
||||
const llvm::SmallDenseMap<Operation*, int64_t>& walk_order,
|
||||
OpBuilder* builder) {
|
||||
auto input_op = input.getDefiningOp();
|
||||
int64_t compile_walk_order = walk_order.find(compile)->getSecond();
|
||||
int64_t compile_walk_order = walk_order.find(compile_launch)->getSecond();
|
||||
int64_t input_walk_order = walk_order.find(input_op)->getSecond();
|
||||
if (compile_walk_order > input_walk_order) {
|
||||
builder->setInsertionPointAfter(get_layout);
|
||||
@ -128,22 +130,22 @@ TF::TPUCopyWithLayoutOp BuildCopyWithLayout(
|
||||
|
||||
// Performs transformation for a non-replicated input.
|
||||
void HandleInput(Value input, int64_t index, TF::TPUExecuteOp execute,
|
||||
Operation* compile,
|
||||
tf_device::LaunchOp compile_launch,
|
||||
const llvm::SmallDenseMap<Operation*, int64_t>& walk_order) {
|
||||
OpBuilder builder(compile->getContext());
|
||||
auto get_layout = BuildGetLayout(compile, index, &builder);
|
||||
auto copy_with_layout = BuildCopyWithLayout(execute, compile, get_layout,
|
||||
input, walk_order, &builder);
|
||||
if (auto device = execute.getAttrOfType<StringAttr>(kDeviceAttr)) {
|
||||
copy_with_layout.setAttr(kDeviceAttr, device);
|
||||
}
|
||||
OpBuilder builder(compile_launch.getContext());
|
||||
auto get_layout = BuildGetLayout(compile_launch, index, &builder);
|
||||
auto copy_with_layout = BuildCopyWithLayout(
|
||||
execute, compile_launch, get_layout, input, walk_order, &builder);
|
||||
copy_with_layout.setAttr(
|
||||
kDeviceAttr,
|
||||
llvm::cast<tf_device::LaunchOp>(execute.getParentOp()).deviceAttr());
|
||||
execute.setOperand(index, copy_with_layout);
|
||||
}
|
||||
|
||||
// Performs transformation for replicated inputs. Returns true if this is a
|
||||
// supported case (thus transform happened).
|
||||
bool HandleReplicatedInputs(
|
||||
int64_t index, TF::TPUExecuteOp execute, Operation* compile,
|
||||
int64_t index, TF::TPUExecuteOp execute, tf_device::LaunchOp compile_launch,
|
||||
int64_t replicate_arg_index, tf_device::ReplicateOp replicate,
|
||||
const llvm::SmallDenseMap<Operation*, int64_t>& walk_order) {
|
||||
// We need to know the devices to copy to.
|
||||
@ -157,10 +159,11 @@ bool HandleReplicatedInputs(
|
||||
if (!input_op || !IsSupportedInputOp(input_op)) return false;
|
||||
}
|
||||
OpBuilder builder(execute.getContext());
|
||||
auto get_layout = BuildGetLayout(compile, index, &builder);
|
||||
auto get_layout = BuildGetLayout(compile_launch, index, &builder);
|
||||
for (auto entry : llvm::enumerate(inputs)) {
|
||||
auto copy_with_layout = BuildCopyWithLayout(
|
||||
execute, compile, get_layout, entry.value(), walk_order, &builder);
|
||||
auto copy_with_layout =
|
||||
BuildCopyWithLayout(execute, compile_launch, get_layout, entry.value(),
|
||||
walk_order, &builder);
|
||||
|
||||
// As model parallelism is not supported yet, assume that all ops are
|
||||
// placed at logical core 0.
|
||||
@ -179,7 +182,7 @@ bool HandleReplicatedInputs(
|
||||
|
||||
// Performs transformation on a pair of execute and compile ops. The compile
|
||||
// should not have other uses.
|
||||
void HandleExecute(TF::TPUExecuteOp execute, Operation* compile,
|
||||
void HandleExecute(TF::TPUExecuteOp execute, tf_device::LaunchOp compile_launch,
|
||||
const llvm::SmallDenseMap<Operation*, int64_t>& walk_order) {
|
||||
auto maybe_replicate = execute.getParentOfType<tf_device::ReplicateOp>();
|
||||
llvm::SmallVector<int64_t, 8> unrestricted_input_indices;
|
||||
@ -188,7 +191,7 @@ void HandleExecute(TF::TPUExecuteOp execute, Operation* compile,
|
||||
// For a block argument, consider transforms only when it is a replicated
|
||||
// input (defining ops will be outside the replicate node).
|
||||
if (maybe_replicate != block_arg.getParentRegion()->getParentOp() ||
|
||||
!HandleReplicatedInputs(input.index(), execute, compile,
|
||||
!HandleReplicatedInputs(input.index(), execute, compile_launch,
|
||||
block_arg.getArgNumber(), maybe_replicate,
|
||||
walk_order)) {
|
||||
continue;
|
||||
@ -204,26 +207,28 @@ void HandleExecute(TF::TPUExecuteOp execute, Operation* compile,
|
||||
continue;
|
||||
}
|
||||
if (!IsSupportedInputOp(input_op)) continue;
|
||||
HandleInput(input.value(), input.index(), execute, compile, walk_order);
|
||||
HandleInput(input.value(), input.index(), execute, compile_launch,
|
||||
walk_order);
|
||||
}
|
||||
unrestricted_input_indices.push_back(input.index());
|
||||
}
|
||||
if (unrestricted_input_indices.empty()) return;
|
||||
|
||||
// Update the compilation metadata if we changed anything.
|
||||
auto metadata_attr = compile->getAttrOfType<StringAttr>("metadata");
|
||||
Operation& compile = compile_launch.GetBody().front();
|
||||
auto metadata_attr = compile.getAttrOfType<StringAttr>("metadata");
|
||||
assert(metadata_attr && "Missing compilation metadata");
|
||||
tensorflow::tpu::TPUCompileMetadataProto metadata;
|
||||
metadata.ParseFromString(std::string(metadata_attr.getValue()));
|
||||
for (int64_t input_index : unrestricted_input_indices) {
|
||||
metadata.mutable_args(input_index)->set_unrestricted_layout(true);
|
||||
}
|
||||
compile->setAttr("metadata", OpBuilder(compile).getStringAttr(
|
||||
metadata.SerializeAsString()));
|
||||
compile.setAttr("metadata", StringAttr::get(metadata.SerializeAsString(),
|
||||
compile.getContext()));
|
||||
}
|
||||
|
||||
void TPUDynamicLayoutPass::runOnFunction() {
|
||||
llvm::SmallVector<std::pair<TF::TPUExecuteOp, Operation*>, 4>
|
||||
llvm::SmallVector<std::pair<TF::TPUExecuteOp, tf_device::LaunchOp>, 4>
|
||||
executes_and_compiles;
|
||||
llvm::SmallDenseMap<Operation*, int64_t> walk_order;
|
||||
int64_t next_walk_order = 0;
|
||||
@ -232,12 +237,17 @@ void TPUDynamicLayoutPass::runOnFunction() {
|
||||
// Detect tf._TPUCompileMlir -> tf.TPUExecute
|
||||
auto execute = llvm::dyn_cast<TF::TPUExecuteOp>(op);
|
||||
if (!execute) return;
|
||||
auto execute_launch =
|
||||
llvm::dyn_cast_or_null<tf_device::LaunchOp>(execute.getParentOp());
|
||||
if (!execute_launch || !execute_launch.WrapsSingleOp()) return;
|
||||
auto compile = execute.key().getDefiningOp();
|
||||
if (!compile || compile->getName().getStringRef() != "tf._TPUCompileMlir" ||
|
||||
!compile->getResult(1).hasOneUse()) {
|
||||
if (!compile || !compile->getResult(1).hasOneUse()) return;
|
||||
auto compile_launch = llvm::dyn_cast<tf_device::LaunchOp>(compile);
|
||||
if (!compile_launch || !compile_launch.WrapsSingleOp() ||
|
||||
compile_launch.GetBody().front().getName().getStringRef() !=
|
||||
"tf._TPUCompileMlir")
|
||||
return;
|
||||
}
|
||||
executes_and_compiles.emplace_back(execute, compile);
|
||||
executes_and_compiles.emplace_back(execute, compile_launch);
|
||||
});
|
||||
for (auto execute_and_compile : executes_and_compiles) {
|
||||
HandleExecute(execute_and_compile.first, execute_and_compile.second,
|
||||
|
@ -121,13 +121,13 @@ bool OpAccessesResource(Operation* op) {
|
||||
// e.g., guaranteed by replication.
|
||||
// - `check_same_region` specifies whether the reads/assigns need to be in the
|
||||
// same region as `execute`. This is needed if `execute` is inside ReplicateOp.
|
||||
VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
|
||||
bool check_device,
|
||||
bool check_same_region) {
|
||||
VariableAccessesForTPUExecute BuildVariableAccessInfo(
|
||||
tf_device::LaunchOp execute_launch, bool check_device,
|
||||
bool check_same_region) {
|
||||
VariableAccessesForTPUExecute infos;
|
||||
auto device_attr = execute->getAttr(kDeviceAttr);
|
||||
Attribute device_attr = execute_launch.deviceAttr();
|
||||
if (check_device && !device_attr) return infos;
|
||||
auto func = execute->getParentOfType<mlir::FuncOp>();
|
||||
auto func = execute_launch.getParentOfType<mlir::FuncOp>();
|
||||
|
||||
// Track the first read op found, which is used later to check if there are
|
||||
// assign ops between it and the TPUExecute op. We will exclude reads before
|
||||
@ -135,21 +135,23 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
|
||||
// consider resource accesses in other islands since they ordering is enforced
|
||||
// by inter-island dependencies.
|
||||
Operation* first_read = nullptr;
|
||||
Operation& execute = execute_launch.GetBody().front();
|
||||
// Find inputs that are variable reads.
|
||||
for (auto operand : llvm::enumerate(execute->getOpOperands())) {
|
||||
for (auto operand : llvm::enumerate(execute.getOpOperands())) {
|
||||
infos.new_operand_values.push_back(operand.value().get());
|
||||
if (!operand.value().get().getDefiningOp()) continue;
|
||||
auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(
|
||||
operand.value().get().getDefiningOp());
|
||||
if (!read_op) continue;
|
||||
if (check_same_region &&
|
||||
read_op.getParentRegion() != execute->getParentRegion()) {
|
||||
read_op.getParentRegion() != execute_launch.getParentRegion()) {
|
||||
continue;
|
||||
}
|
||||
auto resource = read_op.resource();
|
||||
|
||||
if (check_device) {
|
||||
if (auto resource_op = resource.getDefiningOp()) {
|
||||
// TODO(lyandy): Wrap resource ops in tf_device.launch.
|
||||
if (auto* resource_op = resource.getDefiningOp()) {
|
||||
auto resource_attr = resource_op->getAttr(kDeviceAttr);
|
||||
// Check device matching for the node defining the resource.
|
||||
if (!resource_attr || resource_attr != device_attr) continue;
|
||||
@ -169,7 +171,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
|
||||
if (!emplace_res.second) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Skipping execute that has multiple reads of a variable: "
|
||||
<< *execute << "\n");
|
||||
<< execute << "\n");
|
||||
infos.per_resource_info.shrink_and_clear();
|
||||
return infos;
|
||||
}
|
||||
@ -191,8 +193,9 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
|
||||
// work fine for the reads/assigns created by resource lifting, since they are
|
||||
// placed close to the TPUExecute.
|
||||
Operation* last_may_modify_resource_access_before_execute = nullptr;
|
||||
for (Operation& op : llvm::reverse(llvm::make_range(
|
||||
std::next(first_read->getIterator()), execute->getIterator()))) {
|
||||
for (Operation& op : llvm::reverse(
|
||||
llvm::make_range(std::next(first_read->getIterator()),
|
||||
execute_launch.getOperation()->getIterator()))) {
|
||||
if (llvm::dyn_cast<TF::ReadVariableOp>(&op)) continue;
|
||||
if (!OpAccessesResource(&op)) continue;
|
||||
last_may_modify_resource_access_before_execute = &op;
|
||||
@ -209,7 +212,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
|
||||
auto info_it = infos.per_resource_info.find(read.resource());
|
||||
if (info_it == infos.per_resource_info.end()) continue;
|
||||
int input_index = info_it->getSecond().execute_input_index;
|
||||
infos.new_operand_values[input_index] = execute->getOperand(input_index);
|
||||
infos.new_operand_values[input_index] = execute.getOperand(input_index);
|
||||
infos.per_resource_info.erase(info_it);
|
||||
}
|
||||
infos.resources_read.erase(
|
||||
@ -227,9 +230,12 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
|
||||
// Find outputs that are variable assigns.
|
||||
Operation* last_assign = nullptr;
|
||||
llvm::SmallPtrSet<Operation*, 8> all_assigns;
|
||||
llvm::SmallVector<bool, 8> output_fused(execute->getNumResults(), false);
|
||||
for (int i = 0; i < execute->getNumResults(); ++i) {
|
||||
auto result = execute->getResult(i);
|
||||
llvm::SmallVector<bool, 8> output_fused(execute_launch.getNumResults(),
|
||||
false);
|
||||
for (int i = 0; i < execute_launch.getNumResults(); ++i) {
|
||||
// TODO(lyandy): Handle updates to resource writes by remapping to parent
|
||||
// launch result and checking if launch result is an AssignVariableOp.
|
||||
auto result = execute_launch.getResult(i);
|
||||
if (!result.hasOneUse()) continue;
|
||||
auto assign_op = llvm::dyn_cast<TF::AssignVariableOp>(*result.user_begin());
|
||||
if (!assign_op) continue;
|
||||
@ -240,7 +246,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
|
||||
if (info.assign) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "Skipping execute that has multiple assigns of a variable: "
|
||||
<< *execute << "\n");
|
||||
<< execute << "\n");
|
||||
infos.per_resource_info.shrink_and_clear();
|
||||
return infos;
|
||||
}
|
||||
@ -256,8 +262,9 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
|
||||
// Check if there are other resource accesses after execute.
|
||||
Operation* first_unknown_resource_access_after_execute = nullptr;
|
||||
if (last_assign) {
|
||||
for (auto& op : llvm::make_range(std::next(execute->getIterator()),
|
||||
last_assign->getIterator())) {
|
||||
for (auto& op : llvm::make_range(
|
||||
std::next(execute_launch.getOperation()->getIterator()),
|
||||
last_assign->getIterator())) {
|
||||
if (all_assigns.count(&op) > 0) continue;
|
||||
if (!OpAccessesResource(&op)) continue;
|
||||
first_unknown_resource_access_after_execute = &op;
|
||||
@ -282,8 +289,8 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
|
||||
|
||||
// Populate infos.old_to_new_output_mapping.
|
||||
int new_output_index = 0;
|
||||
infos.old_to_new_output_mapping.resize(execute->getNumResults());
|
||||
for (int i = 0; i < execute->getNumResults(); ++i) {
|
||||
infos.old_to_new_output_mapping.resize(execute_launch.getNumResults());
|
||||
for (int i = 0; i < execute_launch.getNumResults(); ++i) {
|
||||
if (output_fused[i]) {
|
||||
infos.old_to_new_output_mapping[i] = -1;
|
||||
} else {
|
||||
@ -295,19 +302,20 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(Operation* execute,
|
||||
}
|
||||
|
||||
// Merges the variable accesses into one TPUExecute op.
|
||||
void MergeForOneTPUExecute(Operation* execute, bool check_device,
|
||||
bool check_same_region, OpBuilder* builder) {
|
||||
void MergeForOneTPUExecute(tf_device::LaunchOp execute_launch,
|
||||
bool check_device, bool check_same_region,
|
||||
OpBuilder* builder) {
|
||||
auto infos =
|
||||
BuildVariableAccessInfo(execute, check_device, check_same_region);
|
||||
BuildVariableAccessInfo(execute_launch, check_device, check_same_region);
|
||||
if (infos.per_resource_info.empty()) {
|
||||
return;
|
||||
}
|
||||
// Start creating the new TPUExecuteAndUpdateVariables op.
|
||||
builder->setInsertionPoint(execute);
|
||||
builder->setInsertionPoint(execute_launch);
|
||||
// Output types. Skip the original outputs for fused assigns.
|
||||
llvm::SmallVector<Type, 8> new_output_types;
|
||||
int old_output_index = 0;
|
||||
for (const auto& type : execute->getResultTypes()) {
|
||||
for (const auto& type : execute_launch.getResultTypes()) {
|
||||
if (infos.old_to_new_output_mapping[old_output_index] >= 0) {
|
||||
new_output_types.push_back(type);
|
||||
}
|
||||
@ -322,7 +330,7 @@ void MergeForOneTPUExecute(Operation* execute, bool check_device,
|
||||
device_var_updates_indices.push_back(info.execute_output_index);
|
||||
}
|
||||
auto merged_execute = builder->create<TF::TPUExecuteAndUpdateVariablesOp>(
|
||||
execute->getLoc(), new_output_types, infos.new_operand_values,
|
||||
execute_launch.getLoc(), new_output_types, infos.new_operand_values,
|
||||
llvm::ArrayRef<NamedAttribute>{
|
||||
builder->getNamedAttr(
|
||||
"device_var_reads_indices",
|
||||
@ -331,15 +339,24 @@ void MergeForOneTPUExecute(Operation* execute, bool check_device,
|
||||
"device_var_updates_indices",
|
||||
builder->getI64ArrayAttr(device_var_updates_indices))});
|
||||
|
||||
if (auto device = execute->getAttr(kDeviceAttr)) {
|
||||
merged_execute.setAttr(kDeviceAttr, device);
|
||||
}
|
||||
// Wrap in launch for device assignment.
|
||||
auto merged_execute_launch = builder->create<tf_device::LaunchOp>(
|
||||
merged_execute.getLoc(), execute_launch.deviceAttr(),
|
||||
merged_execute.getResultTypes());
|
||||
merged_execute_launch.body().push_back(new Block);
|
||||
|
||||
builder->setInsertionPointToEnd(&merged_execute_launch.GetBody());
|
||||
builder->create<tf_device::ReturnOp>(merged_execute.getLoc(),
|
||||
merged_execute.getResults());
|
||||
|
||||
merged_execute.getOperation()->moveBefore(
|
||||
merged_execute_launch.GetBody().getTerminator());
|
||||
|
||||
// Replace the uses.
|
||||
for (int i = 0; i < infos.old_to_new_output_mapping.size(); ++i) {
|
||||
if (infos.old_to_new_output_mapping[i] < 0) continue;
|
||||
execute->getResult(i).replaceAllUsesWith(
|
||||
merged_execute.getResult(infos.old_to_new_output_mapping[i]));
|
||||
execute_launch.getResult(i).replaceAllUsesWith(
|
||||
merged_execute_launch.getResult(infos.old_to_new_output_mapping[i]));
|
||||
}
|
||||
// Remove the assign ops.
|
||||
for (const auto& entry : infos.per_resource_info) {
|
||||
@ -347,7 +364,7 @@ void MergeForOneTPUExecute(Operation* execute, bool check_device,
|
||||
if (info.assign) info.assign->erase();
|
||||
}
|
||||
// Remove the original TPUExecute op.
|
||||
execute->erase();
|
||||
execute_launch.erase();
|
||||
// Remove the read ops if they have no more uses.
|
||||
for (const auto& entry : infos.per_resource_info) {
|
||||
const auto& info = entry.getSecond();
|
||||
@ -358,19 +375,22 @@ void MergeForOneTPUExecute(Operation* execute, bool check_device,
|
||||
void TPUMergeVariablesWithExecutePass::runOnFunction() {
|
||||
// Find all the executes first, since we will mutate the nodes around each
|
||||
// execute.
|
||||
llvm::SmallVector<Operation*, 8> executes;
|
||||
getFunction().walk([&](TF::TPUExecuteOp op) { executes.push_back(op); });
|
||||
llvm::SmallVector<tf_device::LaunchOp, 8> execute_launches;
|
||||
getFunction().walk([&](tf_device::LaunchOp op) {
|
||||
if (op.WrapsSingleOp() && llvm::isa<TF::TPUExecuteOp>(op.GetBody().front()))
|
||||
execute_launches.push_back(op);
|
||||
});
|
||||
|
||||
for (auto execute : executes) {
|
||||
for (auto execute_launch : execute_launches) {
|
||||
OpBuilder builder(&getContext());
|
||||
const bool parent_is_replicate =
|
||||
llvm::isa<tf_device::ReplicateOp>(execute->getParentOp());
|
||||
llvm::isa<tf_device::ReplicateOp>(execute_launch.getParentOp());
|
||||
// If this is inside a tf_device::ReplicateOp, the variables are guaranteed
|
||||
// to be on the same device as the TPUExecute op. Skip device checking in
|
||||
// that case, but we need to check that we are only merging reads/assigns
|
||||
// that are also in this replicated region.
|
||||
MergeForOneTPUExecute(execute, !parent_is_replicate, parent_is_replicate,
|
||||
&builder);
|
||||
MergeForOneTPUExecute(execute_launch, !parent_is_replicate,
|
||||
parent_is_replicate, &builder);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -255,6 +255,26 @@ LogicalResult SetMetadataProtoFromLaunchFuncOp(
|
||||
return success();
|
||||
}
|
||||
|
||||
// Wraps single op in `tf_device.launch` for explicit device assignment.
|
||||
tf_device::LaunchOp WrapOpInLaunch(OpBuilder* builder, Location loc,
|
||||
Operation* op, llvm::StringRef device) {
|
||||
OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();
|
||||
|
||||
auto launch = builder->create<tf_device::LaunchOp>(
|
||||
loc, builder->getStringAttr(device), op->getResultTypes());
|
||||
launch.body().push_back(new Block);
|
||||
|
||||
builder->setInsertionPointToEnd(&launch.GetBody());
|
||||
builder->create<tf_device::ReturnOp>(loc, op->getResults());
|
||||
|
||||
// Move op inside launch.
|
||||
op->moveBefore(launch.GetBody().getTerminator());
|
||||
|
||||
builder->restoreInsertionPoint(insert_point);
|
||||
|
||||
return launch;
|
||||
}
|
||||
|
||||
// Create a `tf._TPUCompileMlir` that contains a MLIR module that is
|
||||
// functionally equivalent to the function referenced by launch_func.
|
||||
Operation* BuildCompileOp(
|
||||
@ -319,9 +339,6 @@ Operation* BuildCompileOp(
|
||||
compile_op_state.addAttribute("mlir_module",
|
||||
builder->getStringAttr(txt_module));
|
||||
|
||||
compile_op_state.addAttribute(kDeviceAttr,
|
||||
builder->getStringAttr(compilation_device));
|
||||
|
||||
// Result #0 is a string indicating whether compilation is successful or not.
|
||||
compile_op_state.addTypes(
|
||||
RankedTensorType::get({}, builder->getType<TF::StringType>()));
|
||||
@ -330,7 +347,10 @@ Operation* BuildCompileOp(
|
||||
compile_op_state.addTypes(
|
||||
RankedTensorType::get({}, builder->getType<TF::StringType>()));
|
||||
|
||||
return builder->createOperation(compile_op_state);
|
||||
Operation* compile_op = builder->createOperation(compile_op_state);
|
||||
|
||||
return WrapOpInLaunch(builder, compile_op->getLoc(), compile_op,
|
||||
compilation_device);
|
||||
}
|
||||
|
||||
// Creates a `tf.TPUExecute` op that executes TPU program generated by
|
||||
@ -377,6 +397,7 @@ Operation* BuildParallelExecuteOp(int num_logical_cores, Operation* compile_op,
|
||||
// For each logical core, create a region with TPUExecute op.
|
||||
for (int core_id = 0; core_id < num_logical_cores; ++core_id) {
|
||||
auto& region = parallel_execute_op.GetRegionBlockWithIndex(core_id);
|
||||
builder->setInsertionPointToEnd(®ion);
|
||||
|
||||
// Create Execute op.
|
||||
//
|
||||
@ -389,20 +410,9 @@ Operation* BuildParallelExecuteOp(int num_logical_cores, Operation* compile_op,
|
||||
//
|
||||
// TODO(b/149102679): Add device attribute to launch op once device
|
||||
// topology for multiple logical cores can be correctly parsed.
|
||||
builder->setInsertionPointToStart(®ion);
|
||||
auto region_launch_op = builder->create<tf_device::LaunchOp>(
|
||||
region.getParent()->getLoc(), builder->getStringAttr(""),
|
||||
launch_func.results().getTypes());
|
||||
region_launch_op.body().push_back(new Block);
|
||||
auto region_launch_op = WrapOpInLaunch(
|
||||
builder, region.getParent()->getLoc(), execute, /*device=*/"");
|
||||
|
||||
builder->setInsertionPointToEnd(®ion_launch_op.GetBody());
|
||||
builder->create<tf_device::ReturnOp>(region_launch_op.getLoc(),
|
||||
execute->getResults());
|
||||
|
||||
// Move execute inside the launch op.
|
||||
execute->moveBefore(®ion_launch_op.GetBody().back());
|
||||
|
||||
builder->setInsertionPointToEnd(®ion);
|
||||
builder->create<tf_device::ReturnOp>(region.getParent()->getLoc(),
|
||||
region_launch_op.getResults());
|
||||
}
|
||||
@ -423,13 +433,14 @@ void RemapOutputsOfParallelExecute(tf_device::LaunchFuncOp launch_func,
|
||||
std::get<0>(outputs).replaceAllUsesWith(std::get<1>(outputs));
|
||||
}
|
||||
|
||||
void AssignDevicesToReplicatedExecute(
|
||||
tf_device::LaunchOp AssignDevicesToReplicatedExecute(
|
||||
llvm::ArrayRef<llvm::SmallVector<std::string, 8>> execution_devices,
|
||||
tf_device::ReplicateOp replicate, Operation* execute_op,
|
||||
OpBuilder* builder) {
|
||||
// If computation is replicated, execution devices are assigned to the
|
||||
// replicate. Otherwise there is only one execution device and the device is
|
||||
// assigned to the execute op.
|
||||
std::string device;
|
||||
if (replicate) {
|
||||
// Model parallelism is not support for now. Therefore, assign all ops
|
||||
// in replicate op with virtual device alias specifying that ops will be
|
||||
@ -439,24 +450,27 @@ void AssignDevicesToReplicatedExecute(
|
||||
for (const auto& replica_execution_devices : execution_devices)
|
||||
replicate_execution_devices.push_back(replica_execution_devices.front());
|
||||
|
||||
device = tensorflow::GetDeviceAliasForLogicalCore(0);
|
||||
auto device_attr = builder->getNamedAttr(
|
||||
tensorflow::GetDeviceAliasForLogicalCore(0),
|
||||
builder->getStrArrayAttr(replicate_execution_devices));
|
||||
device, builder->getStrArrayAttr(replicate_execution_devices));
|
||||
replicate.setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attr));
|
||||
} else {
|
||||
execute_op->setAttr(
|
||||
kDeviceAttr, builder->getStringAttr(execution_devices.front().front()));
|
||||
device = execution_devices.front().front();
|
||||
}
|
||||
|
||||
return WrapOpInLaunch(builder, execute_op->getLoc(), execute_op, device);
|
||||
}
|
||||
|
||||
// Creates a `tf.TPUCompileSucceededAssert` operation that parses compilation
|
||||
// status of `compile_op` to check whether compilation is successful.
|
||||
void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
|
||||
llvm::StringRef compilation_device,
|
||||
OpBuilder* builder) {
|
||||
OperationState assert_op_state(compile_op->getLoc(),
|
||||
"tf.TPUCompileSucceededAssert");
|
||||
assert_op_state.addOperands(compile_op->getResult(0));
|
||||
builder->createOperation(assert_op_state);
|
||||
Operation* assert_op = builder->createOperation(assert_op_state);
|
||||
WrapOpInLaunch(builder, compile_op->getLoc(), assert_op, compilation_device);
|
||||
}
|
||||
|
||||
// Rewrites a `tf_device.launch_func` operation into a set of TPU Runtime
|
||||
@ -547,7 +561,7 @@ LogicalResult Rewrite(
|
||||
<< "error in fetching TPU compilation/execution devices: "
|
||||
<< status_or_tpu_device_assignment.status().error_message();
|
||||
|
||||
// Create compile op;
|
||||
// Create compile op.
|
||||
auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
|
||||
builder->setInsertionPoint(launch_func);
|
||||
Operation* compile_op = BuildCompileOp(
|
||||
@ -564,24 +578,25 @@ LogicalResult Rewrite(
|
||||
for (auto compile_result_op : block->getOps<TF::TPUCompilationResultOp>())
|
||||
compile_result_op.output().replaceAllUsesWith(compile_op->getResult(0));
|
||||
|
||||
BuildTPUCompileSucceededAssertOp(compile_op, builder);
|
||||
BuildTPUCompileSucceededAssertOp(
|
||||
compile_op, tpu_device_assignment.compilation_device, builder);
|
||||
|
||||
Operation* execute_op;
|
||||
if (num_cores_per_replica > 1) {
|
||||
// For model parallelism, tf_device.parallel_execute is used to express
|
||||
// concurrent device execution across multiple logical devices.
|
||||
execute_op = BuildParallelExecuteOp(num_cores_per_replica, compile_op,
|
||||
launch_func, builder);
|
||||
Operation* execute_op = BuildParallelExecuteOp(
|
||||
num_cores_per_replica, compile_op, launch_func, builder);
|
||||
|
||||
RemapOutputsOfParallelExecute(launch_func, execute_op);
|
||||
|
||||
// TODO(hongjunchoi): Correctly parse TPU topology and assign logical device
|
||||
// attributes to launch_op's within parallel_execute op.
|
||||
} else {
|
||||
execute_op = BuildExecuteOp(compile_op, launch_func, builder);
|
||||
AssignDevicesToReplicatedExecute(tpu_device_assignment.execution_devices,
|
||||
replicate, execute_op, builder);
|
||||
launch_func.replaceAllUsesWith(execute_op);
|
||||
Operation* execute_op = BuildExecuteOp(compile_op, launch_func, builder);
|
||||
tf_device::LaunchOp launch_op = AssignDevicesToReplicatedExecute(
|
||||
tpu_device_assignment.execution_devices, replicate, execute_op,
|
||||
builder);
|
||||
launch_func.replaceAllUsesWith(launch_op);
|
||||
}
|
||||
|
||||
launch_func.erase();
|
||||
|
@ -0,0 +1,201 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Block.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/UseDefLists.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/xla/client/sharding_builder.h"
|
||||
#include "tensorflow/compiler/xla/xla.pb.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFTPU {
|
||||
namespace {
|
||||
|
||||
constexpr char kXlaShardingAttr[] = "_XlaSharding";
|
||||
constexpr char kInputShardingAttr[] = "input_sharding_configuration";
|
||||
constexpr char kOutputShardingAttr[] = "output_sharding_configuration";
|
||||
|
||||
struct TPUShardingIdentificationPass
|
||||
: public ModulePass<TPUShardingIdentificationPass> {
|
||||
void runOnModule() override;
|
||||
};
|
||||
|
||||
// XlaSharding op may be direct user of inputs but it may also be followed by
|
||||
// an Identity op and, in the case where bfloat16 type is used, Cast op may be
|
||||
// added right after the input. As so, parse the users of the operation to
|
||||
// access connected XlaSharding op.
|
||||
//
|
||||
// TODO(hongjunchoi): Consider explicitly checking op patterns to detect
|
||||
// sharded inputs.
|
||||
void GetAdjacentToXlaShardingOp(
|
||||
Operation* op, llvm::Optional<TF::XlaShardingOp>* sharding_op) {
|
||||
// TODO(hongjunchoi): Detect the case when sharding configuration is
|
||||
// ambiguous for a single input (i.e. multiple different XlaSharding ops
|
||||
// with different configuration policies are connected).
|
||||
if (sharding_op->hasValue()) return;
|
||||
|
||||
if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(op)) {
|
||||
sharding_op->emplace(sharding);
|
||||
return;
|
||||
}
|
||||
|
||||
if (llvm::isa<TF::IdentityOp>(op) || llvm::isa<TF::CastOp>(op)) {
|
||||
for (auto user : op->getUsers())
|
||||
GetAdjacentToXlaShardingOp(user, sharding_op);
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Optional<StringRef> ParseShardingAttribute(Operation* operation) {
|
||||
const auto& sharding_attr =
|
||||
operation->getAttrOfType<StringAttr>(kXlaShardingAttr);
|
||||
if (!sharding_attr) return llvm::Optional<StringRef>();
|
||||
return sharding_attr.getValue();
|
||||
}
|
||||
|
||||
// Parse XlaSharding op connected to input args. If Input to
|
||||
// tf_device.LaunchFunc op is of resource type, then XlaSharding op
|
||||
// will be connected to following ReadVariable op.
|
||||
//
|
||||
// TODO(hongjunchoi): Add logic to parse XlaSharding op inside a
|
||||
// Call op or if/while op.
|
||||
llvm::Optional<StringRef> ParseInputSharding(const FuncOp func,
|
||||
const int arg_index,
|
||||
const Value& arg) {
|
||||
llvm::Optional<TF::XlaShardingOp> parsed_sharding_op;
|
||||
for (auto user : arg.getUsers()) {
|
||||
if (parsed_sharding_op) continue;
|
||||
|
||||
GetAdjacentToXlaShardingOp(user, &parsed_sharding_op);
|
||||
if (parsed_sharding_op) continue;
|
||||
|
||||
if (llvm::isa<TF::ReadVariableOp>(user))
|
||||
for (auto read_variable_user : user->getUsers())
|
||||
GetAdjacentToXlaShardingOp(read_variable_user, &parsed_sharding_op);
|
||||
}
|
||||
|
||||
if (!parsed_sharding_op) return llvm::Optional<StringRef>();
|
||||
return ParseShardingAttribute(parsed_sharding_op->getOperation());
|
||||
}
|
||||
|
||||
// If operand of return value of tf_device.LaunchFunc op is directly from
|
||||
// XlaSharding op, return the provided sharding configuration.
|
||||
llvm::Optional<StringRef> ParseReturnValueSharding(FuncOp func,
|
||||
const int output_index,
|
||||
const OpOperand& operand) {
|
||||
if (auto sharding_op =
|
||||
llvm::dyn_cast<TF::XlaShardingOp>(operand.get().getDefiningOp())) {
|
||||
return ParseShardingAttribute(sharding_op.getOperation());
|
||||
}
|
||||
|
||||
return llvm::Optional<StringRef>();
|
||||
}
|
||||
|
||||
// Add parsed sharding configuration to tf_device.LaunchFunc op attribute.
|
||||
void SetShardingConfigurationAsAttribute(
|
||||
tf_device::LaunchFuncOp launch_func, const std::string& attr_name,
|
||||
const llvm::SmallVector<std::string, 8>& sharding_config) {
|
||||
auto input_sharding_array_ref = llvm::SmallVector<llvm::StringRef, 8>(
|
||||
sharding_config.begin(), sharding_config.end());
|
||||
launch_func.setAttr(attr_name,
|
||||
mlir::Builder(launch_func.getContext())
|
||||
.getStrArrayAttr(input_sharding_array_ref));
|
||||
}
|
||||
|
||||
// If XlaSharding op is connected to input/output of the tf_device.LaunchFuncOp,
|
||||
// then add attributes to the op specifying the sharding configurations.
|
||||
void IdentifyXlaShardingForTPUComputation(tf_device::LaunchFuncOp launch_func) {
|
||||
// Look up function definition from module.
|
||||
FuncOp func = launch_func.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
|
||||
launch_func.func());
|
||||
Block& func_entry_block = func.getBody().getBlocks().front();
|
||||
|
||||
// By default inputs have maximal sharding and inputs are assigned to
|
||||
// logical core 0 if no sharding is defined.
|
||||
llvm::SmallVector<std::string, 8> sharding_for_args(
|
||||
func_entry_block.getNumArguments(),
|
||||
xla::sharding_builder::AssignDevice(0).SerializeAsString());
|
||||
|
||||
// Iterate through input arguments to the entry block of tf_device.LaunchFunc.
|
||||
// For input ops, look for following XlaSharding ops. XlaSharding ops can
|
||||
// 1) Directly follow the input argument if input argument has non-resource
|
||||
// types.
|
||||
// 2) Follow ReadVariableOp if the input type is of resource type.
|
||||
// 3) Follow IdentityOp or CastOp after above cases (1), (2).
|
||||
for (auto& arg_index_and_value :
|
||||
llvm::enumerate(func_entry_block.getArguments())) {
|
||||
const int arg_index = arg_index_and_value.index();
|
||||
auto& arg = arg_index_and_value.value();
|
||||
auto input_arg_sharding = ParseInputSharding(func, arg_index, arg);
|
||||
|
||||
if (!input_arg_sharding.hasValue()) continue;
|
||||
sharding_for_args[arg_index] = input_arg_sharding->str();
|
||||
}
|
||||
SetShardingConfigurationAsAttribute(launch_func, kInputShardingAttr,
|
||||
sharding_for_args);
|
||||
|
||||
// By default return values from logical core 0 is used if no sharding
|
||||
// configuration is defined.
|
||||
llvm::SmallVector<std::string, 8> sharding_for_return_values(
|
||||
func_entry_block.getTerminator()->getNumOperands(),
|
||||
xla::sharding_builder::AssignDevice(0).SerializeAsString());
|
||||
|
||||
// Iterate through operands of the terminator, if the preceding op is
|
||||
// XlaShardingOp, then add provided sharding configuration to launch func
|
||||
// attribute.
|
||||
for (auto& return_value_and_index :
|
||||
llvm::enumerate(func_entry_block.getTerminator()->getOpOperands())) {
|
||||
int return_value_index = return_value_and_index.index();
|
||||
const auto& return_value = return_value_and_index.value();
|
||||
auto return_val_sharding =
|
||||
ParseReturnValueSharding(func, return_value_index, return_value);
|
||||
|
||||
if (return_val_sharding)
|
||||
sharding_for_return_values[return_value_index] =
|
||||
return_val_sharding->str();
|
||||
}
|
||||
SetShardingConfigurationAsAttribute(launch_func, kOutputShardingAttr,
|
||||
sharding_for_return_values);
|
||||
}
|
||||
|
||||
void TPUShardingIdentificationPass::runOnModule() {
|
||||
getModule().walk([&](tf_device::LaunchFuncOp launch_func) {
|
||||
IdentifyXlaShardingForTPUComputation(launch_func);
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> CreateTPUShardingIdentificationPass() {
|
||||
return std::make_unique<TPUShardingIdentificationPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<TPUShardingIdentificationPass> pass(
|
||||
"tf-tpu-sharding-identification",
|
||||
"Identifies and handles inputs/outputs of TPU computation that is "
|
||||
"sharded across logical cores.");
|
||||
|
||||
} // namespace TFTPU
|
||||
} // namespace mlir
|
@ -38,6 +38,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
|
||||
#include "mlir/Support/STLExtras.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
@ -143,8 +144,8 @@ Value SkipIdentity(Value v, bool allow_other_use,
|
||||
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4>
|
||||
AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
|
||||
TF::WhileOp while_op, tf_device::ReplicateOp replicate,
|
||||
TF::TPUExecuteAndUpdateVariablesOp execute, Operation* compile, FuncOp body,
|
||||
FuncOp cond) {
|
||||
TF::TPUExecuteAndUpdateVariablesOp execute,
|
||||
tf_device::LaunchOp compile_launch, FuncOp body, FuncOp cond) {
|
||||
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4> mapping;
|
||||
auto mirrored_variable_indices_attr =
|
||||
replicate.getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
|
||||
@ -168,7 +169,8 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
|
||||
if (replicate_arg_to_execute_arg.empty()) return mapping;
|
||||
|
||||
// Parse the original compile metadata.
|
||||
auto metadata_str = compile->getAttrOfType<StringAttr>("metadata");
|
||||
Operation& compile = compile_launch.GetBody().front();
|
||||
auto metadata_str = compile.getAttrOfType<StringAttr>("metadata");
|
||||
assert(metadata_str && "Missing compilation metadata");
|
||||
tensorflow::tpu::TPUCompileMetadataProto metadata;
|
||||
metadata.ParseFromString(std::string(metadata_str.getValue()));
|
||||
@ -242,8 +244,8 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
|
||||
}
|
||||
}
|
||||
// Update the metadata of the compile op.
|
||||
compile->setAttr("metadata", OpBuilder(compile).getStringAttr(
|
||||
metadata.SerializeAsString()));
|
||||
compile.setAttr("metadata", StringAttr::get(metadata.SerializeAsString(),
|
||||
compile.getContext()));
|
||||
return mapping;
|
||||
}
|
||||
|
||||
@ -393,26 +395,56 @@ llvm::SmallVector<TF::VarHandleOp, 4> CreateStateVars(
|
||||
return state_vars;
|
||||
}
|
||||
|
||||
// Wraps single op in `tf_device.launch` for explicit device assignment.
|
||||
void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
|
||||
llvm::StringRef device) {
|
||||
OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();
|
||||
|
||||
auto launch = builder->create<tf_device::LaunchOp>(
|
||||
loc, builder->getStringAttr(device), op->getResultTypes());
|
||||
launch.body().push_back(new Block);
|
||||
|
||||
builder->setInsertionPointToEnd(&launch.GetBody());
|
||||
builder->create<tf_device::ReturnOp>(loc, op->getResults());
|
||||
|
||||
// Move op inside launch.
|
||||
op->moveBefore(launch.GetBody().getTerminator());
|
||||
|
||||
builder->restoreInsertionPoint(insert_point);
|
||||
}
|
||||
|
||||
// Performs the transformation for a replicate op inside a while loop.
|
||||
void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
|
||||
MLIRContext* context) {
|
||||
int64_t num_replicas = replicate.n().getLimitedValue();
|
||||
if (num_replicas == 1) return;
|
||||
TF::TPUExecuteAndUpdateVariablesOp execute;
|
||||
for (auto execute_op :
|
||||
replicate.GetBody().getOps<TF::TPUExecuteAndUpdateVariablesOp>()) {
|
||||
if (execute == nullptr) {
|
||||
execute = execute_op;
|
||||
tf_device::LaunchOp execute_launch;
|
||||
for (auto execute_launch_op :
|
||||
replicate.GetBody().getOps<tf_device::LaunchOp>()) {
|
||||
if (!execute_launch_op.WrapsSingleOp() ||
|
||||
!llvm::isa<TF::TPUExecuteAndUpdateVariablesOp>(
|
||||
execute_launch_op.GetBody().front()))
|
||||
continue;
|
||||
|
||||
if (execute_launch == nullptr) {
|
||||
execute_launch = execute_launch_op;
|
||||
} else {
|
||||
// We only support one execute op inside replicate.
|
||||
execute = nullptr;
|
||||
execute_launch = nullptr;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!execute) return;
|
||||
if (!execute_launch) return;
|
||||
auto execute = llvm::cast<TF::TPUExecuteAndUpdateVariablesOp>(
|
||||
execute_launch.GetBody().front());
|
||||
auto compile =
|
||||
SkipIdentity(execute.key(), /*allow_other_use=*/true).getDefiningOp();
|
||||
if (!compile) return;
|
||||
auto compile_launch = llvm::dyn_cast<tf_device::LaunchOp>(compile);
|
||||
if (!compile_launch || !compile_launch.WrapsSingleOp() ||
|
||||
compile_launch.GetBody().front().getName().getStringRef() !=
|
||||
"tf._TPUCompileMlir")
|
||||
return;
|
||||
|
||||
auto module = while_op.getParentOfType<ModuleOp>();
|
||||
auto body = llvm::cast<FuncOp>(module.lookupSymbol(while_op.body()));
|
||||
@ -421,7 +453,7 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
|
||||
// Analyze the formattable inputs.
|
||||
auto execute_arg_to_outer_args =
|
||||
AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
|
||||
while_op, replicate, execute, compile, body, cond);
|
||||
while_op, replicate, execute, compile_launch, body, cond);
|
||||
if (execute_arg_to_outer_args.empty()) return;
|
||||
|
||||
// Extract the replicated devices.
|
||||
@ -468,13 +500,15 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
|
||||
for (const auto& entry : execute_arg_to_outer_args) {
|
||||
reformat_operands.push_back(execute.args()[entry.first]);
|
||||
}
|
||||
reformat_operands.push_back(compile->getResult(1));
|
||||
reformat_operands.push_back(compile_launch.getResult(1));
|
||||
reformat_operands.push_back(replicate.GetBody().getArgument(
|
||||
replicate.GetBody().getNumArguments() - 1));
|
||||
builder.setInsertionPoint(execute);
|
||||
builder.create<TF::TPUReshardVariablesOp>(
|
||||
execute.getLoc(), llvm::ArrayRef<Type>{}, reformat_operands,
|
||||
builder.setInsertionPoint(execute_launch);
|
||||
auto reformat_op = builder.create<TF::TPUReshardVariablesOp>(
|
||||
execute_launch.getLoc(), llvm::ArrayRef<Type>{}, reformat_operands,
|
||||
llvm::ArrayRef<NamedAttribute>{});
|
||||
WrapOpInLaunch(&builder, execute_launch.getLoc(), reformat_op,
|
||||
execute_launch.device());
|
||||
|
||||
// Build the replicated unformat op after the loop. First prepare building the
|
||||
// replicate op.
|
||||
@ -514,9 +548,11 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate,
|
||||
unformat_operands.begin() + unformat_operands.size() - 1,
|
||||
default_state_key.getResult());
|
||||
// Unformat op.
|
||||
builder.create<TF::TPUReshardVariablesOp>(
|
||||
auto unformat_op = builder.create<TF::TPUReshardVariablesOp>(
|
||||
while_op.getLoc(), llvm::ArrayRef<Type>{}, unformat_operands,
|
||||
llvm::ArrayRef<NamedAttribute>{});
|
||||
WrapOpInLaunch(&builder, execute_launch.getLoc(), unformat_op,
|
||||
execute_launch.device());
|
||||
builder.create<tf_device::ReturnOp>(while_op.getLoc(), ArrayRef<Value>{});
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"
|
||||
|
||||
#include <climits>
|
||||
#include <cstdint>
|
||||
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
|
||||
@ -35,16 +36,11 @@ limitations under the License.
|
||||
#include "mlir/Support/Functional.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/core/util/matmul_bcast.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
namespace TF {
|
||||
|
||||
namespace {
|
||||
// Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out
|
||||
@ -75,7 +71,7 @@ TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
|
||||
Type resultType = RankedTensorType::get(shape, element_type);
|
||||
auto constant_attr = DenseElementsAttr::get(shape_spec_type, shape);
|
||||
auto shape_tensor =
|
||||
rewriter.create<ConstantOp>(loc, shape_spec_type, constant_attr);
|
||||
rewriter.create<TF::ConstOp>(loc, shape_spec_type, constant_attr);
|
||||
return rewriter.create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
|
||||
/*shape=*/shape_tensor);
|
||||
}
|
||||
@ -108,8 +104,8 @@ std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
|
||||
auto begin_attr =
|
||||
DenseElementsAttr::get<int64_t>(vector3_type, {batch_idx, 0, 0});
|
||||
auto size_attr = DenseElementsAttr::get<int64_t>(vector3_type, slice_size);
|
||||
auto begin = rewriter.create<ConstantOp>(loc, vector3_type, begin_attr);
|
||||
auto size = rewriter.create<ConstantOp>(loc, vector3_type, size_attr);
|
||||
auto begin = rewriter.create<TF::ConstOp>(loc, vector3_type, begin_attr);
|
||||
auto size = rewriter.create<TF::ConstOp>(loc, vector3_type, size_attr);
|
||||
auto slice_op = rewriter.create<TF::SliceOp>(loc, slice_result_type,
|
||||
/*input=*/reshape_op.output(),
|
||||
begin, size);
|
||||
@ -312,8 +308,12 @@ PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
|
||||
}
|
||||
|
||||
static PassRegistration<UnrollBatchMatMulPass> pass(
|
||||
"tfl-unroll-batch-matmul",
|
||||
"tf-unroll-batch-matmul",
|
||||
"Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops.");
|
||||
|
||||
} // namespace TFL
|
||||
std::unique_ptr<OpPassBase<FuncOp>> CreateUnrollBatchMatMulPassPass() {
|
||||
return std::make_unique<UnrollBatchMatMulPass>();
|
||||
}
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/matmul_bcast.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
namespace TF {
|
||||
|
||||
// Unroll tf.BatchMatMulV2 op into a sequence of TF ops. Since TFLite does not
|
||||
// support BatchMatMul operation, it unrolls a BatchMatMul op into tf.Reshape,
|
||||
@ -53,7 +53,7 @@ class ConvertTFBatchMatMulOp : public OpRewritePattern<BatchMatMulOpType> {
|
||||
PatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
|
@ -40,11 +40,13 @@ limitations under the License.
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Analysis/Verifier.h" // TF:llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
|
||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||
#include "mlir/IR/Identifier.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
@ -63,6 +65,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
|
||||
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
|
||||
@ -162,7 +165,8 @@ class ImporterBase {
|
||||
specs_(specs),
|
||||
debug_info_(debug_info),
|
||||
function_name_for_debug_info_(function_name_for_debug_info),
|
||||
function_name_uniquifier_(function_name_uniquifier) {}
|
||||
function_name_uniquifier_(function_name_uniquifier),
|
||||
error_handler_(module.getContext()) {}
|
||||
|
||||
// Returns the inferred function signature of the given function body. Input
|
||||
// types are unranked tensor of the respective datatype in the function and
|
||||
@ -318,8 +322,9 @@ class ImporterBase {
|
||||
// the location.
|
||||
mlir::Location GetLocation(const NodeDef& node);
|
||||
|
||||
// Gets the location information string for the given node.
|
||||
std::string GetLocationStr(const Node& node, bool includeNodeName = false);
|
||||
// Appends the location string for the node to the error message and returns
|
||||
// the combined error status.
|
||||
Status EmitErrorWithLocationStr(const Node& node, const Status& error_status);
|
||||
|
||||
// Inserts a placeholder node in the graph to replace a feed output tensor,
|
||||
// and returns the new placeholder node and a boolean indicating if the
|
||||
@ -368,6 +373,7 @@ class ImporterBase {
|
||||
NodeValueMap node_values_;
|
||||
std::unique_ptr<ShapeRefiner> shape_refiner_;
|
||||
NameUniquifier* function_name_uniquifier_;
|
||||
mlir::StatusScopedDiagnosticHandler error_handler_;
|
||||
|
||||
protected:
|
||||
// Maps feed as TensorId to new Placeholder node name.
|
||||
@ -669,9 +675,10 @@ Status ImporterBase::AddNodesToShapeRefiner() {
|
||||
remapped_feeds_[{it->first, index}] = placeholder_node->name();
|
||||
node_name_map[placeholder_node->name()] = placeholder_node;
|
||||
// Add the new placeholder node to the shape refiner.
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
shape_refiner_->AddNode(placeholder_node),
|
||||
GetLocationStr(*placeholder_node));
|
||||
Status status = shape_refiner_->AddNode(placeholder_node);
|
||||
if (!status.ok()) {
|
||||
return EmitErrorWithLocationStr(*placeholder_node, status);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto index_it = it->second.find(0);
|
||||
@ -691,8 +698,10 @@ Status ImporterBase::AddNodesToShapeRefiner() {
|
||||
}
|
||||
if (!node_added_to_shape_refiner) {
|
||||
// Add the node to the shape refiner if the node hasn't been removed.
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(shape_refiner_->AddNode(node),
|
||||
GetLocationStr(*node));
|
||||
Status status = shape_refiner_->AddNode(node);
|
||||
if (!status.ok()) {
|
||||
return EmitErrorWithLocationStr(*node, status);
|
||||
}
|
||||
}
|
||||
|
||||
auto set_shape_from_list_attr = [&](const AttrValue* attr) {
|
||||
@ -700,9 +709,11 @@ Status ImporterBase::AddNodesToShapeRefiner() {
|
||||
for (auto shape : llvm::enumerate(list.shape())) {
|
||||
auto* node_context = shape_refiner_->GetContext(node);
|
||||
shape_inference::ShapeHandle handle;
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
node_context->MakeShapeFromShapeProto(shape.value(), &handle),
|
||||
GetLocationStr(*node));
|
||||
Status status =
|
||||
node_context->MakeShapeFromShapeProto(shape.value(), &handle);
|
||||
if (!status.ok()) {
|
||||
return EmitErrorWithLocationStr(*node, status);
|
||||
}
|
||||
node_context->set_output(shape.index(), handle);
|
||||
}
|
||||
return Status::OK();
|
||||
@ -735,9 +746,11 @@ Status ImporterBase::AddNodesToShapeRefiner() {
|
||||
DCHECK(node_context != nullptr);
|
||||
if (const AttrValue* attr = node->attrs().Find("shape")) {
|
||||
shape_inference::ShapeHandle handle;
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
node_context->MakeShapeFromShapeProto(attr->shape(), &handle),
|
||||
GetLocationStr(*node));
|
||||
Status status =
|
||||
node_context->MakeShapeFromShapeProto(attr->shape(), &handle);
|
||||
if (!status.ok()) {
|
||||
return EmitErrorWithLocationStr(*node, status);
|
||||
}
|
||||
node_context->set_output(0, handle);
|
||||
} else if (const AttrValue* attr = node->attrs().Find("_output_shapes")) {
|
||||
TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr));
|
||||
@ -813,9 +826,12 @@ Status ImporterBase::AddNodesToShapeRefiner() {
|
||||
existing.push_back(shape_context->output(o));
|
||||
}
|
||||
bool inferred = false;
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
shape_refiner_->UpdateNode(node, /*relax=*/false, &inferred),
|
||||
GetLocationStr(*node));
|
||||
shape_inference::ShapeHandle handle;
|
||||
Status status =
|
||||
shape_refiner_->UpdateNode(node, /*relax=*/false, &inferred);
|
||||
if (!status.ok()) {
|
||||
return EmitErrorWithLocationStr(*node, status);
|
||||
}
|
||||
for (int o = 0; o < shape_context->num_outputs(); ++o) {
|
||||
if (!same_inferred_shape(shape_context, shape_context->output(o),
|
||||
existing[o])) {
|
||||
@ -1346,18 +1362,11 @@ mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) {
|
||||
}
|
||||
}
|
||||
|
||||
std::string ImporterBase::GetLocationStr(const Node& node,
|
||||
bool includeNodeName) {
|
||||
const auto location = GetLocation(node.def());
|
||||
std::string s;
|
||||
llvm::raw_string_ostream ss(s);
|
||||
location.print(ss);
|
||||
ss.flush();
|
||||
// Removes the node name prefix if it exists.
|
||||
if (!s.empty() && s[0] == '\"' && s.find_first_of(node.name()) == 1) {
|
||||
return s.replace(0, node.name().size() + 3, "");
|
||||
}
|
||||
return s;
|
||||
Status ImporterBase::EmitErrorWithLocationStr(const Node& node,
|
||||
const Status& error_status) {
|
||||
const mlir::Location location = GetLocation(node.def());
|
||||
mlir::emitError(location);
|
||||
return error_handler_.Combine(error_status);
|
||||
}
|
||||
|
||||
mlir::Operation* ImporterBase::createOperation(
|
||||
|
@ -210,6 +210,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
|
||||
bool use_tuple_args, bool return_tuple) {
|
||||
mlir::PassManager tf2xla(module_op.getContext());
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass());
|
||||
tf2xla.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass());
|
||||
tf2xla.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
|
||||
// LegalizeTFControlFlow encapsulates arguments for control flow operations
|
||||
@ -221,6 +222,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
|
||||
// and canonicalization opportunities that are necessary for the second
|
||||
// LegalizeTFPass(allow_partial_conversion=false) invocation.
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::xla_hlo::createLegalizeTFPass(true));
|
||||
tf2xla.addPass(mlir::tf_executor::CreateTFExecutorGraphPruningPass());
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(
|
||||
mlir::xla_hlo::createLegalizeTFPass(false));
|
||||
|
@ -49,7 +49,9 @@ Status LoadProtoFromFile(absl::string_view input_filename,
|
||||
const auto file_or_err =
|
||||
llvm::MemoryBuffer::getFileOrSTDIN(StringViewToRef(input_filename));
|
||||
if (std::error_code error = file_or_err.getError()) {
|
||||
return errors::InvalidArgument("Could not open input file");
|
||||
return errors::InvalidArgument(
|
||||
"Could not open input file ",
|
||||
string(input_filename.data(), input_filename.size()).c_str());
|
||||
}
|
||||
|
||||
const auto& input_file = *file_or_err;
|
||||
|
@ -433,6 +433,7 @@ cc_library(
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":convert_op_folder",
|
||||
":hlo",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"@llvm-project//mlir:IR",
|
||||
|
@ -198,7 +198,8 @@ tensorflow::Status HloFunctionImporter::ImportInstructions(
|
||||
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
HloInstruction* instruction, mlir::OpBuilder* func_builder) {
|
||||
TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction));
|
||||
TF_ASSIGN_OR_RETURN(auto result_type, ConvertType(instruction->shape()));
|
||||
TF_ASSIGN_OR_RETURN(auto result_type, ConvertShapeToType<RankedTensorType>(
|
||||
instruction->shape(), *builder_));
|
||||
llvm::SmallVector<NamedAttribute, 10> attributes = {builder_->getNamedAttr(
|
||||
"name", builder_->getStringAttr(instruction->name()))};
|
||||
mlir::Location loc = func_builder->getUnknownLoc();
|
||||
@ -699,29 +700,12 @@ StatusOr<llvm::SmallVector<mlir::Value, 4>> HloFunctionImporter::GetOperands(
|
||||
return operands;
|
||||
}
|
||||
|
||||
StatusOr<mlir::Type> HloFunctionImporter::ConvertType(const Shape& shape) {
|
||||
if (shape.IsToken()) {
|
||||
return mlir::xla_hlo::TokenType::get(builder_->getContext());
|
||||
}
|
||||
if (shape.IsTuple()) {
|
||||
llvm::SmallVector<mlir::Type, 4> contents;
|
||||
contents.reserve(shape.tuple_shapes_size());
|
||||
for (const auto& subtype : shape.tuple_shapes()) {
|
||||
TF_ASSIGN_OR_RETURN(auto mlir_subtype, ConvertType(subtype));
|
||||
contents.push_back(mlir_subtype);
|
||||
}
|
||||
|
||||
return builder_->getTupleType(contents);
|
||||
}
|
||||
|
||||
return ConvertTensorShapeToType<RankedTensorType>(shape, *builder_);
|
||||
}
|
||||
|
||||
tensorflow::Status HloFunctionImporter::GetMlirTypes(
|
||||
const std::vector<HloInstruction*>& instructions,
|
||||
llvm::SmallVectorImpl<mlir::Type>* types) {
|
||||
for (auto instruction : instructions) {
|
||||
TF_ASSIGN_OR_RETURN(auto ret_type, ConvertType(instruction->shape()));
|
||||
TF_ASSIGN_OR_RETURN(auto ret_type, ConvertShapeToType<RankedTensorType>(
|
||||
instruction->shape(), *builder_));
|
||||
types->push_back(ret_type);
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
|
@ -78,9 +78,6 @@ class HloFunctionImporter {
|
||||
// Converts xla Tensor type to the corresponding MLIR type.
|
||||
StatusOr<mlir::RankedTensorType> ConvertTensorType(const xla::Shape& shape);
|
||||
|
||||
// Converts xla Primitive types to the corresponding MLIR type.
|
||||
StatusOr<mlir::Type> ConvertType(const xla::Shape& shape);
|
||||
|
||||
// Returns the output type of an HloInstruction.
|
||||
StatusOr<mlir::Type> GetReturnType(xla::HloInstruction* instruction);
|
||||
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
|
||||
namespace xla {
|
||||
@ -69,6 +70,9 @@ static StatusOr<mlir::Type> ConvertShapeToType(const Shape& shape,
|
||||
}
|
||||
return builder.getTupleType(contents);
|
||||
}
|
||||
if (shape.IsToken()) {
|
||||
return mlir::xla_hlo::TokenType::get(builder.getContext());
|
||||
}
|
||||
return ConvertTensorShapeToType<TypeT>(shape, builder);
|
||||
}
|
||||
|
||||
|
@ -32,11 +32,13 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||
#include "mlir/IR/Attributes.h" // TF:llvm-project
|
||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||
#include "mlir/IR/Dialect.h" // TF:llvm-project
|
||||
#include "mlir/IR/Location.h" // TF:llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||
#include "mlir/IR/Matchers.h" // TF:llvm-project
|
||||
#include "mlir/IR/OpDefinition.h" // TF:llvm-project
|
||||
#include "mlir/IR/OpImplementation.h" // TF:llvm-project
|
||||
#include "mlir/IR/Operation.h" // TF:llvm-project
|
||||
@ -46,6 +48,7 @@ limitations under the License.
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
|
||||
#include "mlir/IR/Types.h" // TF:llvm-project
|
||||
#include "mlir/IR/Value.h" // TF:llvm-project
|
||||
#include "mlir/Support/LLVM.h" // TF:llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
|
||||
#include "mlir/Transforms/InliningUtils.h" // TF:llvm-project
|
||||
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
|
||||
@ -479,6 +482,48 @@ static LogicalResult Verify(BroadcastInDimOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ScalarsToDimensionTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
|
||||
// Canonicalizes the pattern of the form
|
||||
//
|
||||
// %2 = "xla_hlo.scalars_to_dimension_tensor"(%0, %1)
|
||||
// : (i32, i32) -> tensor<2xi32>
|
||||
// %3 = extract_element %2[%c0] : tensor<2xi32>
|
||||
//
|
||||
// to just %0.
|
||||
struct ExtractElementFromScalarsToDimensionTensor
|
||||
: public OpRewritePattern<ExtractElementOp> {
|
||||
using OpRewritePattern<ExtractElementOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(ExtractElementOp extract,
|
||||
PatternRewriter& rewriter) const override {
|
||||
if (extract.indices().size() != 1) return matchFailure();
|
||||
|
||||
if (auto scalars_to_tensor = dyn_cast_or_null<ScalarsToDimensionTensorOp>(
|
||||
extract.aggregate().getDefiningOp())) {
|
||||
APInt index;
|
||||
if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index))) {
|
||||
return matchFailure();
|
||||
}
|
||||
rewriter.replaceOp(extract,
|
||||
scalars_to_tensor.getOperand(index.getZExtValue()));
|
||||
return matchSuccess();
|
||||
}
|
||||
return matchFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void ScalarsToDimensionTensorOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList& results, MLIRContext* context) {
|
||||
results.insert<ExtractElementFromScalarsToDimensionTensor>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DynamicBroadcastInDimOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -784,11 +784,12 @@ def HLO_ScalarsToDimensionTensorOp : HLO_Op<"scalars_to_dimension_tensor",
|
||||
compute shape arguments to dynamic operations.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnySignlessInteger>);
|
||||
let arguments = (ins Variadic<AnySignlessInteger>:$scalars);
|
||||
let results = (outs HLO_DimensionTensor);
|
||||
|
||||
// Cannot be exported to legacy formats.
|
||||
let hasCustomHLOConverter = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
|
||||
|
@ -64,3 +64,13 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
|
||||
%0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @extract_scalars_to_tensor
|
||||
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
|
||||
func @extract_scalars_to_tensor(%arg0: i32, %arg1: i32) -> i32 {
|
||||
%0 = "xla_hlo.scalars_to_dimension_tensor"(%arg0, %arg1) : (i32, i32) -> tensor<2xi32>
|
||||
%1 = constant 0 : index
|
||||
%2 = extract_element %0[%1] : tensor<2xi32>
|
||||
// CHECK: return %[[ARG0]]
|
||||
return %2 : i32
|
||||
}
|
||||
|
@ -453,6 +453,14 @@ func @rng_uniform_invalid_type(%mu: tensor<complex<f32>>, %sigma: tensor<f32>) -
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @scalars_to_dimension_tensor
|
||||
func @scalars_to_dimension_tensor(%arg0: i32, %arg1: i32) -> tensor<2xi32> {
|
||||
%0 = "xla_hlo.scalars_to_dimension_tensor"(%arg0, %arg1) : (i32, i32) -> tensor<2xi32>
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @select
|
||||
func @select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
|
||||
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
|
@ -35,7 +35,7 @@ from tensorflow.python.ops import image_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def GenerateNumpyRandomRGB(shape):
|
||||
def _generate_numpy_random_rgb(shape):
|
||||
# Only generate floating points that are fractions like n / 256, since they
|
||||
# are RGB pixels. Some low-precision floating point types in this test can't
|
||||
# handle arbitrary precision floating points well.
|
||||
@ -51,7 +51,7 @@ class RGBToHSVTest(xla_test.XLATestCase):
|
||||
shape = (batch_size, 2, 7, 3)
|
||||
|
||||
for nptype in self.float_types:
|
||||
inp = GenerateNumpyRandomRGB(shape).astype(nptype)
|
||||
inp = _generate_numpy_random_rgb(shape).astype(nptype)
|
||||
|
||||
# Convert to HSV and back, as a batch and individually
|
||||
with self.session() as sess:
|
||||
@ -89,7 +89,7 @@ class RGBToHSVTest(xla_test.XLATestCase):
|
||||
def testRGBToHSVNumpy(self):
|
||||
"""Tests the RGB to HSV conversion matches a reference implementation."""
|
||||
for nptype in self.float_types:
|
||||
rgb_flat = GenerateNumpyRandomRGB((64, 3)).astype(nptype)
|
||||
rgb_flat = _generate_numpy_random_rgb((64, 3)).astype(nptype)
|
||||
rgb_np = rgb_flat.reshape(4, 4, 4, 3)
|
||||
hsv_np = np.array([
|
||||
colorsys.rgb_to_hsv(
|
||||
|
@ -153,6 +153,10 @@ class TRTEngineOp : public AsyncOpKernel {
|
||||
// Whether to use implicit batch dimension for TensorRT.
|
||||
bool use_implicit_batch_;
|
||||
|
||||
// Whether to collect optimization profiles for TensorRT, only used when
|
||||
// use_implicit_batch_=false.
|
||||
bool profile_generation_mode_;
|
||||
|
||||
// Whether to build TensorRT engines at runtime.
|
||||
bool allow_build_at_runtime_;
|
||||
|
||||
@ -322,7 +326,19 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
|
||||
use_implicit_batch_ = true;
|
||||
}
|
||||
#endif
|
||||
status =
|
||||
context->GetAttr("_profile_generation_mode", &profile_generation_mode_);
|
||||
if (status.code() == tensorflow::error::NOT_FOUND) {
|
||||
VLOG(2) << "Not found _profile_generation_mode in "
|
||||
<< context->device()->name()
|
||||
<< ", thus setting _profile_generation_mode=false";
|
||||
profile_generation_mode_ = false;
|
||||
}
|
||||
if (use_implicit_batch_) {
|
||||
OP_REQUIRES(context, !profile_generation_mode_,
|
||||
errors::InvalidArgument(
|
||||
"profile_generation_mode_=true is only supported if "
|
||||
"use_implicit_batch=false"));
|
||||
if (input_partial_shapes_.empty()) {
|
||||
VLOG(1) << "Attribute input_shapes is not set. This happens probably "
|
||||
<< "because you are using a model that is already converted "
|
||||
@ -547,11 +563,21 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
|
||||
OP_REQUIRES_OK_ASYNC(ctx, VerifyInputShapes(input_concrete_shapes), *helper);
|
||||
|
||||
if (!use_implicit_batch_) {
|
||||
if (cache_res->profiles_.GetNumProfiles() == 0) {
|
||||
// Create a single profile from the current input shape. In the future we
|
||||
// will collect a set of input shapes during build mode and create
|
||||
// profiles for each of them.
|
||||
if (profile_generation_mode_) {
|
||||
// Collecting new shapes for profiles can be only done once. After the
|
||||
// shapes are converted to TRT profiles, no shapes can be collected
|
||||
// anymore.
|
||||
OP_REQUIRES(ctx, cache_res->profiles_.GetNumProfiles() == 0,
|
||||
errors::Unimplemented("Cannot collect new shapes when "
|
||||
"profiles are already created."));
|
||||
// Just collect the input shape info and return. The shapes are used to
|
||||
// generate optimization profiles during engine creation.
|
||||
cache_res->profiles_.AddShape(input_concrete_shapes);
|
||||
VLOG(1) << "Native segment is used during collecting shapes for profiles";
|
||||
ExecuteNativeSegment(ctx, helper);
|
||||
return;
|
||||
} else if (cache_res->profiles_.GetNumProfiles() == 0) {
|
||||
// Create profiles out of collected shapes during profile generation.
|
||||
cache_res->profiles_.InitProfiles();
|
||||
}
|
||||
}
|
||||
|
@ -238,12 +238,16 @@ TEST_F(TRTEngineOpTestBase, ExplicitBatch) {
|
||||
device_->resource_manager()->Lookup("TF-TRT", "myop", &cache_resource));
|
||||
core::ScopedUnref sc(cache_resource);
|
||||
|
||||
// The cache should contain only one EngineContext, with a valid cuda_engine.
|
||||
// Due to the way the engine lookup is implemented, explicit batch mode
|
||||
// requires profile generation. Currently profile generaton is not enabled in
|
||||
// this test therfore engine creation fails.
|
||||
//
|
||||
// TODO(Tamas) find a way to enable profile generation mode and test it
|
||||
auto cache = &cache_resource->cache_;
|
||||
EXPECT_EQ(1, cache->size());
|
||||
ASSERT_EQ(1, cache->count({input_shape}));
|
||||
EngineContext* ectx = cache->at({input_shape}).get();
|
||||
EXPECT_NE(ectx->cuda_engine, nullptr);
|
||||
EXPECT_EQ(0, cache->size());
|
||||
// ASSERT_EQ(1, cache->count({input_shape}));
|
||||
// EngineContext* ectx = cache->at({input_shape}).get();
|
||||
// EXPECT_NE(ectx->cuda_engine, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(TRTEngineOpTestBase, DynamicShapes) {
|
||||
@ -267,12 +271,13 @@ TEST_F(TRTEngineOpTestBase, DynamicShapes) {
|
||||
device_->resource_manager()->Lookup("TF-TRT", "myop", &cache_resource));
|
||||
core::ScopedUnref sc(cache_resource);
|
||||
|
||||
// The cache should contain only one EngineContext.
|
||||
// We did not have profile generation mode therfore engine creation failed.
|
||||
// TODO(Tamas) find a way to enable profile generation mode and test it
|
||||
auto cache = &cache_resource->cache_;
|
||||
EXPECT_EQ(1, cache->size());
|
||||
ASSERT_EQ(1, cache->count({input_shape}));
|
||||
EngineContext* ectx = cache->at({input_shape}).get();
|
||||
EXPECT_NE(ectx->cuda_engine, nullptr);
|
||||
EXPECT_EQ(0, cache->size());
|
||||
// ASSERT_EQ(1, cache->count({input_shape}));
|
||||
// EngineContext* ectx = cache->at({input_shape}).get();
|
||||
// EXPECT_NE(ectx->cuda_engine, nullptr);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -70,7 +70,15 @@ void* TRTDeviceAllocator::allocate(uint64_t size, uint64_t alignment,
|
||||
// TODO(aaroey): AllocateRaw takes size_t size as input, so it'll produce
|
||||
// unexpected result when TRT tries to allocate more bytes than size_t can
|
||||
// carry. Fix this.
|
||||
void* mem = allocator_->AllocateRaw(alignment, total_size);
|
||||
//
|
||||
// Fail immediately if allocation fails, rather than waiting 10 seconds and
|
||||
// failing then anyway.
|
||||
// TensorRT 7 can also switch to a different algorithm for a layer if an
|
||||
// algorithm uses too much memory. If we don't fail immediately building the
|
||||
// engine can be *very* slow with TensorRT7 when GPU memory is limited.
|
||||
AllocationAttributes attributes;
|
||||
attributes.no_retry_on_failure = true;
|
||||
void* mem = allocator_->AllocateRaw(alignment, total_size, attributes);
|
||||
if (!mem) return nullptr;
|
||||
|
||||
void* alloc_mem = mem;
|
||||
|
@ -1,4 +1,4 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test", "tf_openmp_copts")
|
||||
load(
|
||||
"//tensorflow/core/platform/default:cuda_build_defs.bzl",
|
||||
"if_cuda_is_configured",
|
||||
@ -11,6 +11,7 @@ load(
|
||||
load("//tensorflow/compiler/xla:xla.bzl", "xla_py_proto_library")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
|
||||
load("//tensorflow/compiler/xla/service/cpu:build_defs.bzl", "runtime_copts")
|
||||
|
||||
package(
|
||||
default_visibility = [":internal"],
|
||||
@ -180,6 +181,76 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
# The filegroups below are explicitly used by
|
||||
# tensorflow/tools/pip_package:build_pip_package to ensure we include the proper
|
||||
# sources for the XLA AOT CPU runtime; as these are necessary outside of bazel
|
||||
# when linking tfcompile objects using saved_model_cli (e.g. using the
|
||||
# tensorflow pip package). The associated .cc files are included in tensorflow
|
||||
# pip package's xla_aot_runtime_srcs/ subdirectory. All necessary headers are
|
||||
# also included in the pip package's include/tensorflow/ and include/external/
|
||||
# subdirectories. Note however that sometimes additional object files may need
|
||||
# to be linked when linking aot xla objects, e.g. abseil libraries. See the deps
|
||||
# attribute of the "xla_compiled_cpu_runtime_standalone" target below for an
|
||||
# exhaustive list.
|
||||
filegroup(
|
||||
name = "xla_compiled_cpu_runtime_hdrs",
|
||||
srcs = [
|
||||
"xla_compiled_cpu_function.h",
|
||||
"//tensorflow/compiler/xla:cpu_runtime_hdrs",
|
||||
"//tensorflow/compiler/xla/service/cpu:single_threaded_runtime_hdrs",
|
||||
"//tensorflow/core/kernels:xla_cpu_runtime_hdrs",
|
||||
"//tensorflow/core/platform:xla_cpu_runtime_srcs",
|
||||
],
|
||||
visibility = ["//tensorflow/tools/pip_package:__pkg__"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "xla_compiled_cpu_runtime_srcs",
|
||||
srcs = [
|
||||
"xla_compiled_cpu_function.cc",
|
||||
"//tensorflow/compiler/xla:cpu_runtime_srcs",
|
||||
"//tensorflow/compiler/xla/service/cpu:single_threaded_runtime_srcs",
|
||||
"//tensorflow/core/kernels:xla_cpu_runtime_srcs",
|
||||
"//tensorflow/core/platform:xla_cpu_runtime_srcs",
|
||||
],
|
||||
visibility = ["//tensorflow/tools/pip_package:__pkg__"],
|
||||
)
|
||||
|
||||
# This stand-alone target is used to ensure that we can build tf_library type
|
||||
# targets against the subset of sources declared in
|
||||
# xla_compiled_cpu_runtime_{srcs,hdrs}.
|
||||
#
|
||||
# The macros in tensorflow/python/tools/tools.bzl produce AOT compiled binaries
|
||||
# that rely on this target, as do unit tests in tensorflow/python/tools.
|
||||
#
|
||||
# See above for the significance of the source filegroups.
|
||||
cc_library(
|
||||
name = "xla_compiled_cpu_runtime_standalone",
|
||||
srcs = [
|
||||
":xla_compiled_cpu_runtime_srcs",
|
||||
],
|
||||
hdrs = [
|
||||
":xla_compiled_cpu_runtime_hdrs",
|
||||
],
|
||||
copts = runtime_copts() + tf_openmp_copts(),
|
||||
features = ["fully_static_link"],
|
||||
linkstatic = 1,
|
||||
visibility = [":friends"],
|
||||
# Note, we specifically remove MKL dependencies so the standalone does
|
||||
# not require the MKL binary blob.
|
||||
deps = [
|
||||
"//tensorflow/core/framework:numeric_types",
|
||||
"//third_party/eigen3",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/base:dynamic_annotations",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:cord",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_compiled_cpu_function",
|
||||
srcs = ["xla_compiled_cpu_function.cc"],
|
||||
@ -190,7 +261,7 @@ cc_library(
|
||||
# binary produced by tfcompile.
|
||||
"//tensorflow/compiler/xla:cpu_function_runtime",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core/platform:types",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -36,6 +36,25 @@ filegroup(
|
||||
]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "cpu_runtime_srcs",
|
||||
srcs = [
|
||||
"cpu_function_runtime.cc",
|
||||
"executable_run_options.cc",
|
||||
],
|
||||
visibility = [":friends"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "cpu_runtime_hdrs",
|
||||
srcs = [
|
||||
"cpu_function_runtime.h",
|
||||
"executable_run_options.h",
|
||||
"types.h",
|
||||
],
|
||||
visibility = [":friends"],
|
||||
)
|
||||
|
||||
tf_proto_library_cc(
|
||||
name = "xla_data_proto",
|
||||
srcs = ["xla_data.proto"],
|
||||
@ -142,7 +161,8 @@ cc_library(
|
||||
hdrs = ["types.h"],
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core/framework:numeric_types",
|
||||
"//tensorflow/core/platform:types",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
@ -620,7 +640,6 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":types",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -896,7 +915,10 @@ cc_library(
|
||||
srcs = ["cpu_function_runtime.cc"],
|
||||
hdrs = ["cpu_function_runtime.h"],
|
||||
visibility = [":friends"],
|
||||
deps = ["//tensorflow/core:framework_lite"],
|
||||
deps = [
|
||||
"//tensorflow/core/platform:dynamic_annotations",
|
||||
"//tensorflow/core/platform:types",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
|
@ -271,8 +271,7 @@ static ShapedBuffer MaybeOwningShapeTreeToShapedBuffer(
|
||||
|
||||
StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
|
||||
absl::Span<Shape const* const> argument_host_shapes,
|
||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
||||
ExecutableRunOptions run_options) {
|
||||
std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options) {
|
||||
if (argument_host_shapes.size() != arguments.size()) {
|
||||
return InvalidArgument(
|
||||
"Number of argument host shapes not equal to number of arguments (%d "
|
||||
@ -291,8 +290,8 @@ StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
|
||||
shaped_buffer_ptrs.reserve(arguments.size());
|
||||
for (size_t i = 0; i < arguments.size(); ++i) {
|
||||
shaped_buffers.push_back(MaybeOwningShapeTreeToShapedBuffer(
|
||||
*argument_host_shapes[i], arguments[i], backend_->platform(),
|
||||
stream->parent()->device_ordinal()));
|
||||
*argument_host_shapes[i], arguments[i].Buffers(),
|
||||
backend_->platform(), stream->parent()->device_ordinal()));
|
||||
shaped_buffer_ptrs.push_back(&shaped_buffers.back());
|
||||
}
|
||||
|
||||
|
@ -61,8 +61,7 @@ class LocalExecutable {
|
||||
// executable.
|
||||
StatusOr<ExecutionOutput> RunAsync(
|
||||
absl::Span<Shape const* const> argument_host_shapes,
|
||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
||||
ExecutableRunOptions run_options);
|
||||
std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options);
|
||||
|
||||
// Return the options used to build the executable.
|
||||
const ExecutableBuildOptions& build_options() const { return build_options_; }
|
||||
@ -76,8 +75,8 @@ class LocalExecutable {
|
||||
//
|
||||
// The given ExecutableRunOptions override any values from TF_XLA_FLAGS
|
||||
// environment variable.
|
||||
Status ValidateExecutionOptions(
|
||||
const ExecutableRunOptions& run_options, const Backend& backend);
|
||||
Status ValidateExecutionOptions(const ExecutableRunOptions& run_options,
|
||||
const Backend& backend);
|
||||
|
||||
// Returns a literal containing the contents of the given ShapedBuffer.
|
||||
StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);
|
||||
|
@ -70,45 +70,6 @@ void SetProtoIdAndName(T* entry, const string& base_name, char separator,
|
||||
entry->set_id(id);
|
||||
entry->set_name(GetFullName(base_name, separator, id));
|
||||
}
|
||||
|
||||
template <typename InstructionType>
|
||||
StatusOr<InstructionType> LookUpInstructionByHandleInternal(
|
||||
const absl::flat_hash_map<int64, int64>& handle_to_index,
|
||||
const std::vector<HloInstructionProto>& instructions, int64 handle) {
|
||||
auto it = handle_to_index.find(handle);
|
||||
if (it == handle_to_index.end()) {
|
||||
return InvalidArgument("No XlaOp with handle %d", handle);
|
||||
}
|
||||
return const_cast<InstructionType>(&instructions.at(it->second));
|
||||
}
|
||||
|
||||
Status CheckBuildersAffinity(const XlaBuilder* op_builder,
|
||||
const XlaBuilder* builder, int64 handle) {
|
||||
if (op_builder != builder) {
|
||||
return InvalidArgument(
|
||||
"XlaOp with handle %d is built by builder '%s', but is trying to use "
|
||||
"it in builder '%s'",
|
||||
handle, op_builder->name(), builder->name());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename InstructionType, typename OpBuilderType,
|
||||
typename BuilderType, typename OpType>
|
||||
StatusOr<InstructionType> LookUpInstructionInternal(
|
||||
const absl::flat_hash_map<int64, int64>& handle_to_index,
|
||||
const std::vector<HloInstructionProto>& instructions,
|
||||
OpBuilderType op_builder, BuilderType builder, OpType op_handle) {
|
||||
if (op_builder == nullptr) {
|
||||
return InvalidArgument(
|
||||
"Invalid XlaOp with handle %d; the builder of this op is freed",
|
||||
op_handle);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(CheckBuildersAffinity(op_builder, builder, op_handle));
|
||||
return LookUpInstructionByHandleInternal<InstructionType>(
|
||||
handle_to_index, instructions, op_handle);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
XlaOp operator-(XlaOp x) { return Neg(x); }
|
||||
@ -143,7 +104,7 @@ XlaOp operator>>(XlaOp x, XlaOp y) {
|
||||
|
||||
StatusOr<const Shape*> XlaBuilder::GetShapePtr(XlaOp op) const {
|
||||
TF_RETURN_IF_ERROR(first_error_);
|
||||
TF_RETURN_IF_ERROR(CheckBuildersAffinity(op.builder(), this, op.handle()));
|
||||
TF_RETURN_IF_ERROR(CheckOpBuilder(op));
|
||||
auto it = handle_to_index_.find(op.handle());
|
||||
if (it == handle_to_index_.end()) {
|
||||
return InvalidArgument("No XlaOp with handle %d", op.handle());
|
||||
@ -1941,6 +1902,16 @@ XlaOp XlaBuilder::ConditionalImpl(
|
||||
});
|
||||
}
|
||||
|
||||
Status XlaBuilder::CheckOpBuilder(XlaOp op) const {
|
||||
if (this != op.builder()) {
|
||||
return InvalidArgument(
|
||||
"XlaOp with handle %d is built by builder '%s', but is trying to use "
|
||||
"it in builder '%s'",
|
||||
op.handle(), op.builder()->name(), name());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::Reduce(XlaOp operand, XlaOp init_value,
|
||||
const XlaComputation& computation,
|
||||
absl::Span<const int64> dimensions_to_reduce) {
|
||||
@ -2886,27 +2857,23 @@ void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
|
||||
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
|
||||
const XlaOp op) const {
|
||||
TF_RETURN_IF_ERROR(first_error_);
|
||||
return LookUpInstructionInternal<const HloInstructionProto*>(
|
||||
handle_to_index_, instructions_, op.builder_, this, op.handle());
|
||||
return LookUpInstructionInternal<const HloInstructionProto*>(op);
|
||||
}
|
||||
|
||||
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
|
||||
int64 handle) const {
|
||||
return LookUpInstructionByHandleInternal<const HloInstructionProto*>(
|
||||
handle_to_index_, instructions_, handle);
|
||||
return LookUpInstructionByHandleInternal<const HloInstructionProto*>(handle);
|
||||
}
|
||||
|
||||
StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstruction(
|
||||
const XlaOp op) {
|
||||
TF_RETURN_IF_ERROR(first_error_);
|
||||
return LookUpInstructionInternal<HloInstructionProto*>(
|
||||
handle_to_index_, instructions_, op.builder_, this, op.handle());
|
||||
return LookUpInstructionInternal<HloInstructionProto*>(op);
|
||||
}
|
||||
|
||||
StatusOr<HloInstructionProto*> XlaBuilder::LookUpMutableInstructionByHandle(
|
||||
int64 handle) {
|
||||
return LookUpInstructionByHandleInternal<HloInstructionProto*>(
|
||||
handle_to_index_, instructions_, handle);
|
||||
return LookUpInstructionByHandleInternal<HloInstructionProto*>(handle);
|
||||
}
|
||||
|
||||
// Enqueues a "retrieve parameter value" instruction for a parameter that was
|
||||
|
@ -1061,6 +1061,34 @@ class XlaBuilder {
|
||||
XlaOp branch_index,
|
||||
absl::Span<const XlaComputation* const> branch_computations,
|
||||
absl::Span<const XlaOp> branch_operands);
|
||||
|
||||
// Returns OK status if the given op was built using this builder. Otherwise,
|
||||
// returns an error.
|
||||
Status CheckOpBuilder(XlaOp op) const;
|
||||
|
||||
// Here, InstructionType is either const HloInstructionProto* or non-const
|
||||
// HloInstructionProto*.
|
||||
template <typename InstructionType>
|
||||
StatusOr<InstructionType> LookUpInstructionByHandleInternal(
|
||||
int64 handle) const {
|
||||
auto it = handle_to_index_.find(handle);
|
||||
if (it == handle_to_index_.end()) {
|
||||
return InvalidArgument("No XlaOp with handle %d", handle);
|
||||
}
|
||||
return const_cast<InstructionType>(&instructions_.at(it->second));
|
||||
}
|
||||
|
||||
// Here, InstructionType is either const HloInstructionProto* or non-const
|
||||
// HloInstructionProto*.
|
||||
//
|
||||
// TODO(hinsu): Return const pointer within StatusOr and use
|
||||
// absl::implicit_cast at callsites. This requires implicit_cast support in
|
||||
// stream_executor::port::StatusOr similar to absl::StatusOr.
|
||||
template <typename InstructionType>
|
||||
StatusOr<InstructionType> LookUpInstructionInternal(XlaOp op) const {
|
||||
TF_RETURN_IF_ERROR(CheckOpBuilder(op));
|
||||
return LookUpInstructionByHandleInternal<InstructionType>(op.handle());
|
||||
}
|
||||
};
|
||||
|
||||
// RAII-style object: sets the current sharding assignment in builder on
|
||||
|
@ -17,8 +17,6 @@ limitations under the License.
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
RunId::RunId() {
|
||||
@ -28,7 +26,9 @@ RunId::RunId() {
|
||||
|
||||
bool operator==(const RunId& a, const RunId& b) { return a.data_ == b.data_; }
|
||||
|
||||
std::string RunId::ToString() const { return absl::StrCat("RunId: ", data_); }
|
||||
std::string RunId::ToString() const {
|
||||
return "RunId: " + std::to_string(data_);
|
||||
}
|
||||
|
||||
ExecutableRunOptions& ExecutableRunOptions::set_device_ordinal(
|
||||
int device_ordinal) {
|
||||
|
1
tensorflow/compiler/xla/service/BUILD
Executable file → Normal file
1
tensorflow/compiler/xla/service/BUILD
Executable file → Normal file
@ -3109,6 +3109,7 @@ cc_library(
|
||||
deps = [
|
||||
":heap_simulator",
|
||||
":hlo_cost_analysis",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -30,6 +30,32 @@ filegroup(
|
||||
]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "single_threaded_runtime_srcs",
|
||||
srcs = [
|
||||
"runtime_fp16.cc",
|
||||
"runtime_key_value_sort.cc",
|
||||
"runtime_single_threaded_conv2d.cc",
|
||||
"runtime_single_threaded_fft.cc",
|
||||
"runtime_single_threaded_matmul.cc",
|
||||
],
|
||||
visibility = [":friends"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "single_threaded_runtime_hdrs",
|
||||
srcs = [
|
||||
"runtime_conv2d_impl.h",
|
||||
"runtime_fft_impl.h",
|
||||
"runtime_fp16.h",
|
||||
"runtime_key_value_sort.h",
|
||||
"runtime_single_threaded_conv2d.h",
|
||||
"runtime_single_threaded_fft.h",
|
||||
"runtime_single_threaded_matmul.h",
|
||||
],
|
||||
visibility = [":friends"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cpu_transfer_manager",
|
||||
srcs = ["cpu_transfer_manager.cc"],
|
||||
@ -219,7 +245,8 @@ cc_library(
|
||||
],
|
||||
copts = runtime_copts(),
|
||||
deps = [
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core/platform:macros",
|
||||
"//tensorflow/core/platform:types",
|
||||
],
|
||||
)
|
||||
|
||||
@ -545,8 +572,10 @@ cc_library(
|
||||
deps = [
|
||||
":runtime_lightweight_check",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core/kernels:eigen_helpers",
|
||||
"//tensorflow/core/kernels:eigen_helpers_no_mkl",
|
||||
"//tensorflow/core/platform:dynamic_annotations",
|
||||
"//tensorflow/core/platform:mutex",
|
||||
"//tensorflow/core/platform:types",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
@ -563,7 +592,8 @@ cc_library(
|
||||
":runtime_conv2d",
|
||||
":runtime_single_threaded_conv2d",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core/platform:dynamic_annotations",
|
||||
"//tensorflow/core/platform:types",
|
||||
"//tensorflow/core/kernels:eigen_helpers",
|
||||
"//third_party/eigen3",
|
||||
] + mkl_deps(),
|
||||
@ -581,8 +611,10 @@ cc_library(
|
||||
deps = [
|
||||
":runtime_lightweight_check",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core/framework:numeric_types",
|
||||
"//tensorflow/core/platform:dynamic_annotations",
|
||||
"//tensorflow/core/platform:mutex",
|
||||
"//tensorflow/core/platform:types",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
@ -596,8 +628,10 @@ cc_library(
|
||||
deps = [
|
||||
":runtime_lightweight_check",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core/kernels:eigen_contraction_kernel",
|
||||
"//tensorflow/core/platform:dynamic_annotations",
|
||||
"//tensorflow/core/platform:mutex",
|
||||
"//tensorflow/core/platform:types",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
@ -610,7 +644,7 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core/platform:types",
|
||||
"//third_party/eigen3",
|
||||
] + mkl_deps(),
|
||||
)
|
||||
@ -626,8 +660,9 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":runtime_lightweight_check",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core/kernels:eigen_helpers",
|
||||
"//tensorflow/core/platform:dynamic_annotations",
|
||||
"//tensorflow/core/platform:types",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
@ -643,7 +678,9 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core/framework:numeric_types",
|
||||
"//tensorflow/core/platform:dynamic_annotations",
|
||||
"//tensorflow/core/platform:types",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
@ -655,8 +692,9 @@ cc_library(
|
||||
copts = runtime_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core/kernels:eigen_contraction_kernel",
|
||||
"//tensorflow/core/platform:dynamic_annotations",
|
||||
"//tensorflow/core/platform:types",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
@ -668,7 +706,9 @@ cc_library(
|
||||
copts = runtime_copts(),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework_lite",
|
||||
"//tensorflow/core/platform:dynamic_annotations",
|
||||
"//tensorflow/core/platform:macros",
|
||||
"//tensorflow/core/platform:types",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
@ -681,8 +721,10 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/platform:blocking_counter",
|
||||
"//tensorflow/core/platform:dynamic_annotations",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/platform:types",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
@ -711,6 +753,23 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "runtime_fft_test",
|
||||
srcs = [
|
||||
"runtime_fft_impl.h",
|
||||
"runtime_fft_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":runtime_single_threaded_fft",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core/framework:numeric_types",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "cpu_instruction_fusion_test",
|
||||
srcs = ["cpu_instruction_fusion_test.cc"],
|
||||
|
@ -78,9 +78,9 @@ CpuExecutable::CpuExecutable(
|
||||
StatusOr<std::tuple<std::vector<se::DeviceMemoryBase>,
|
||||
std::vector<se::OwningDeviceMemory>,
|
||||
std::vector<se::OwningDeviceMemory>>>
|
||||
CpuExecutable::CreateBufferTable(
|
||||
se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
|
||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments) {
|
||||
CpuExecutable::CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
|
||||
int device_ordinal,
|
||||
std::vector<ExecutionInput> arguments) {
|
||||
std::vector<se::DeviceMemoryBase> unowning_buffers(
|
||||
assignment_->Allocations().size());
|
||||
std::vector<se::OwningDeviceMemory> owning_buffers(
|
||||
@ -95,7 +95,7 @@ CpuExecutable::CreateBufferTable(
|
||||
|
||||
if (allocation.is_entry_computation_parameter()) {
|
||||
unowning_buffers[i] = arguments[allocation.parameter_number()]
|
||||
.element(allocation.param_shape_index())
|
||||
.Buffer(allocation.param_shape_index())
|
||||
.AsDeviceMemoryBase();
|
||||
CHECK_EQ(allocation.size(), unowning_buffers[i].size())
|
||||
<< "Size mismatch on param " << allocation.parameter_number()
|
||||
@ -139,9 +139,9 @@ CpuExecutable::CreateBufferTable(
|
||||
VLOG(3) << "result index: " << result_slice.index();
|
||||
|
||||
std::vector<se::OwningDeviceMemory> buffers_to_free;
|
||||
for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
|
||||
for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
|
||||
auto maybe_owning_buffer = buffer.second.Release();
|
||||
for (auto& argument : arguments) {
|
||||
for (auto& index_buffer : *argument.MutableBuffers()) {
|
||||
auto maybe_owning_buffer = index_buffer.second.Release();
|
||||
if (maybe_owning_buffer) {
|
||||
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
|
||||
}
|
||||
@ -284,7 +284,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
|
||||
|
||||
StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
||||
std::vector<ExecutionInput> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
if (GetRootValueSet().IsAmbiguous()) {
|
||||
return Unimplemented("Points-to set of root instruction is ambiguous");
|
||||
@ -297,7 +297,7 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
|
||||
for (int64 i = 0; i < entry_comp->num_parameters(); ++i) {
|
||||
const Shape& expected_shape =
|
||||
entry_comp->parameter_instruction(i)->shape();
|
||||
const Shape& actual_shape = arguments[i].shape();
|
||||
const Shape& actual_shape = arguments[i].Buffers().shape();
|
||||
CHECK(
|
||||
Shape::Equal().IgnoreDynamicDimension()(expected_shape, actual_shape))
|
||||
<< absl::StreamFormat(
|
||||
@ -355,9 +355,7 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
|
||||
std::make_shared<std::vector<se::OwningDeviceMemory>>(
|
||||
std::move(owning_buffers)),
|
||||
hlo_execution_profile});
|
||||
|
||||
return ExecutionOutput(std::move(result), std::move(buffers_to_release), {},
|
||||
se::OwningDeviceMemory());
|
||||
return ExecutionOutput(std::move(result), std::move(buffers_to_release));
|
||||
}
|
||||
|
||||
/*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) {
|
||||
|
@ -57,7 +57,7 @@ class CpuExecutable : public Executable {
|
||||
|
||||
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
||||
std::vector<ExecutionInput> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) override;
|
||||
|
||||
// This should be called after set_ir_module_string.
|
||||
@ -103,8 +103,7 @@ class CpuExecutable : public Executable {
|
||||
std::vector<se::OwningDeviceMemory>,
|
||||
std::vector<se::OwningDeviceMemory>>>
|
||||
CreateBufferTable(se::DeviceMemoryAllocator* memory_allocator,
|
||||
int device_ordinal,
|
||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments);
|
||||
int device_ordinal, std::vector<ExecutionInput> arguments);
|
||||
|
||||
// Calls the generated function performing the computation with the given
|
||||
// arguments using the supplied buffers.
|
||||
|
@ -33,7 +33,8 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenFft(
|
||||
const xla::ExecutableRunOptions* run_options =
|
||||
static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
|
||||
XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr);
|
||||
tensorflow::xla::EigenFftImpl(*run_options->intra_op_thread_pool(), out,
|
||||
operand, fft_type, fft_rank, input_batch,
|
||||
fft_length0, fft_length1, fft_length2);
|
||||
tensorflow::xla::EigenFftImpl(
|
||||
*run_options->intra_op_thread_pool(), out, operand,
|
||||
static_cast<tensorflow::xla::FftType>(fft_type), fft_rank, input_batch,
|
||||
fft_length0, fft_length1, fft_length2);
|
||||
}
|
||||
|
@ -19,7 +19,6 @@ limitations under the License.
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -28,6 +27,15 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace xla {
|
||||
|
||||
enum class FftType : int32 {
|
||||
FFT = 0, // Forward FFT; complex in, complex out.
|
||||
IFFT = 1, // Inverse FFT; complex in, complex out.
|
||||
RFFT = 2, // Forward real FFT; real in, fft_length / 2 + 1 complex out
|
||||
IRFFT = 3, // Inverse real FFT; fft_length / 2 + 1 complex in,
|
||||
// fft_length real out
|
||||
};
|
||||
static constexpr int kFftTypeArraySize = 4;
|
||||
|
||||
namespace internal {
|
||||
|
||||
// Computes either a forward or reverse complex-to-complex FFT.
|
||||
@ -170,27 +178,27 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand,
|
||||
|
||||
template <int FFTRank, typename EigenDevice>
|
||||
void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
|
||||
int32 fft_type, int64 input_batch, int64 fft_length0,
|
||||
FftType fft_type, int64 input_batch, int64 fft_length0,
|
||||
int64 fft_length1, int64 fft_length2) {
|
||||
switch (fft_type) {
|
||||
case ::xla::FftType::FFT:
|
||||
case FftType::FFT:
|
||||
EigenFftC2C<true, FFTRank, EigenDevice>(
|
||||
device, static_cast<complex64*>(out),
|
||||
static_cast<complex64*>(operand), input_batch, fft_length0,
|
||||
fft_length1, fft_length2);
|
||||
break;
|
||||
case ::xla::FftType::IFFT:
|
||||
case FftType::IFFT:
|
||||
EigenFftC2C<false, FFTRank, EigenDevice>(
|
||||
device, static_cast<complex64*>(out),
|
||||
static_cast<complex64*>(operand), input_batch, fft_length0,
|
||||
fft_length1, fft_length2);
|
||||
break;
|
||||
case ::xla::FftType::RFFT:
|
||||
case FftType::RFFT:
|
||||
EigenFftR2C<FFTRank, EigenDevice>(
|
||||
device, static_cast<complex64*>(out), static_cast<float*>(operand),
|
||||
input_batch, fft_length0, fft_length1, fft_length2);
|
||||
break;
|
||||
case ::xla::FftType::IRFFT:
|
||||
case FftType::IRFFT:
|
||||
EigenFftC2R<FFTRank, EigenDevice>(
|
||||
device, static_cast<float*>(out), static_cast<complex64*>(operand),
|
||||
input_batch, fft_length0, fft_length1, fft_length2);
|
||||
@ -205,7 +213,7 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand,
|
||||
|
||||
template <typename EigenDevice>
|
||||
void EigenFftImpl(const EigenDevice& device, void* out, void* operand,
|
||||
int32 fft_type, int32 fft_rank, int64 input_batch,
|
||||
FftType fft_type, int32 fft_rank, int64 input_batch,
|
||||
int64 fft_length0, int64 fft_length1, int64 fft_length2) {
|
||||
switch (fft_rank) {
|
||||
case 1:
|
||||
|
31
tensorflow/compiler/xla/service/cpu/runtime_fft_test.cc
Normal file
31
tensorflow/compiler/xla/service/cpu/runtime_fft_test.cc
Normal file
@ -0,0 +1,31 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
TEST(FftTypeTest, MatchesProto) {
|
||||
EXPECT_EQ(::xla::FftType_ARRAYSIZE, 4);
|
||||
EXPECT_EQ(::tensorflow::xla::kFftTypeArraySize, 4);
|
||||
EXPECT_EQ(::xla::FftType::FFT,
|
||||
static_cast<::tensorflow::int32>(::tensorflow::xla::FftType::FFT));
|
||||
EXPECT_EQ(::xla::FftType::IFFT,
|
||||
static_cast<::tensorflow::int32>(::tensorflow::xla::FftType::IFFT));
|
||||
EXPECT_EQ(::xla::FftType::RFFT,
|
||||
static_cast<::tensorflow::int32>(::tensorflow::xla::FftType::RFFT));
|
||||
EXPECT_EQ(::xla::FftType::IRFFT, static_cast<::tensorflow::int32>(
|
||||
::tensorflow::xla::FftType::IRFFT));
|
||||
}
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||
#include "tensorflow/core/platform/blocking_counter.h"
|
||||
#include "tensorflow/core/platform/dynamic_annotations.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
@ -26,7 +26,8 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedFft(
|
||||
const void* run_options_ptr, void* out, void* operand, int32 fft_type,
|
||||
int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1,
|
||||
int64 fft_length2) {
|
||||
tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand, fft_type,
|
||||
tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand,
|
||||
static_cast<tensorflow::xla::FftType>(fft_type),
|
||||
fft_rank, input_batch, fft_length0, fft_length1,
|
||||
fft_length2);
|
||||
}
|
||||
|
@ -44,15 +44,13 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStream(
|
||||
return result;
|
||||
}
|
||||
|
||||
static ShapeTree<MaybeOwningDeviceMemory> MakeMaybeOwningDeviceMemoryTree(
|
||||
static ExecutionInput MakeMaybeOwningDeviceMemoryTree(
|
||||
const ShapedBuffer& shaped_buffer) {
|
||||
ShapeTree<MaybeOwningDeviceMemory> result(shaped_buffer.on_device_shape());
|
||||
auto in_it = shaped_buffer.buffers().begin();
|
||||
auto out_it = result.begin();
|
||||
for (; in_it != shaped_buffer.buffers().end(); ++in_it, ++out_it) {
|
||||
DCHECK(out_it != result.end());
|
||||
out_it->second = MaybeOwningDeviceMemory(in_it->second);
|
||||
}
|
||||
ExecutionInput result(shaped_buffer.on_device_shape());
|
||||
shaped_buffer.buffers().ForEachElement(
|
||||
[&](const ShapeIndex& index, const se::DeviceMemoryBase& mem) {
|
||||
result.SetBuffer(index, MaybeOwningDeviceMemory(mem));
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -60,7 +58,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
absl::Span<const ShapedBuffer* const> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> args(arguments.size());
|
||||
std::vector<ExecutionInput> args(arguments.size());
|
||||
auto out_it = args.begin();
|
||||
for (const ShapedBuffer* arg : arguments) {
|
||||
*out_it++ = MakeMaybeOwningDeviceMemoryTree(*arg);
|
||||
@ -73,7 +71,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStream(
|
||||
|
||||
StatusOr<ExecutionOutput> Executable::ExecuteOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
||||
std::vector<ExecutionInput> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
StatusOr<ExecutionOutput> result = ExecuteAsyncOnStream(
|
||||
run_options, std::move(arguments), hlo_execution_profile);
|
||||
@ -238,7 +236,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteAsyncOnStreamWrapper(
|
||||
|
||||
StatusOr<ExecutionOutput> Executable::ExecuteAsyncOnStreamWrapper(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments) {
|
||||
std::vector<ExecutionInput> arguments) {
|
||||
auto state = ExecuteWrapperBeforeExecution(*this, run_options);
|
||||
StatusOr<ExecutionOutput> return_value = ExecuteAsyncOnStream(
|
||||
run_options, std::move(arguments), state.profile_ptr.get());
|
||||
|
@ -42,18 +42,75 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// TODO(b/150633678): Both the ExecutionInput and ExecutionOutput need to be
|
||||
// revisited, with the execute APIs taking data structure which can better model
|
||||
// shareable buffers.
|
||||
class ExecutionInput {
|
||||
public:
|
||||
ExecutionInput() = default;
|
||||
explicit ExecutionInput(xla::Shape shape) : buffers_(std::move(shape)) {}
|
||||
explicit ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers)
|
||||
: buffers_(std::move(buffers)) {}
|
||||
ExecutionInput(ShapeTree<MaybeOwningDeviceMemory> buffers,
|
||||
std::vector<ShapeIndex> owner_held_indices)
|
||||
: buffers_(std::move(buffers)),
|
||||
unowned_indices_(std::move(owner_held_indices)) {}
|
||||
ExecutionInput(ExecutionInput&&) = default;
|
||||
|
||||
~ExecutionInput() {
|
||||
for (auto& index : unowned_indices_) {
|
||||
auto buffer = buffers_.mutable_element(index)->Release();
|
||||
if (buffer) {
|
||||
buffer->Release();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ExecutionInput& operator=(ExecutionInput&&) = default;
|
||||
|
||||
const Shape& shape() const { return buffers_.shape(); }
|
||||
|
||||
void SetBuffer(const ShapeIndex& index, MaybeOwningDeviceMemory buffer) {
|
||||
*buffers_.mutable_element(index) = std::move(buffer);
|
||||
}
|
||||
|
||||
void SetUnownedBuffer(const ShapeIndex& index,
|
||||
MaybeOwningDeviceMemory buffer) {
|
||||
*buffers_.mutable_element(index) = std::move(buffer);
|
||||
unowned_indices_.push_back(index);
|
||||
}
|
||||
|
||||
const ShapeTree<MaybeOwningDeviceMemory>& Buffers() const { return buffers_; }
|
||||
|
||||
ShapeTree<MaybeOwningDeviceMemory>* MutableBuffers() { return &buffers_; }
|
||||
|
||||
MaybeOwningDeviceMemory* MutableBuffer(const ShapeIndex& index) {
|
||||
return buffers_.mutable_element(index);
|
||||
}
|
||||
|
||||
const MaybeOwningDeviceMemory& Buffer(const ShapeIndex& index) const {
|
||||
return buffers_.element(index);
|
||||
}
|
||||
|
||||
private:
|
||||
ShapeTree<MaybeOwningDeviceMemory> buffers_;
|
||||
std::vector<ShapeIndex> unowned_indices_;
|
||||
};
|
||||
|
||||
// ExecutionOutput encapsulates the output buffers of a execution and the
|
||||
// leftover buffers to be released by the caller.
|
||||
class ExecutionOutput {
|
||||
public:
|
||||
explicit ExecutionOutput(ScopedShapedBuffer result)
|
||||
: result_(std::move(result)) {}
|
||||
ExecutionOutput(ScopedShapedBuffer result,
|
||||
std::vector<se::OwningDeviceMemory> to_be_released,
|
||||
std::vector<ShapeIndex> aliased_indices,
|
||||
se::OwningDeviceMemory output_shape_table)
|
||||
std::vector<se::OwningDeviceMemory> to_be_released)
|
||||
: result_(std::move(result)),
|
||||
to_be_released_(std::move(to_be_released)),
|
||||
aliased_indices_(std::move(aliased_indices)),
|
||||
output_shape_table_(std::move(output_shape_table)) {}
|
||||
to_be_released_(std::move(to_be_released)) {}
|
||||
ExecutionOutput(Shape on_host_shape, Shape on_device_shape,
|
||||
se::DeviceMemoryAllocator* allocator, int device_ordinal)
|
||||
: result_(std::move(on_host_shape), std::move(on_device_shape), allocator,
|
||||
device_ordinal) {}
|
||||
ExecutionOutput(ExecutionOutput&&) = default;
|
||||
ExecutionOutput& operator=(ExecutionOutput&&) = default;
|
||||
|
||||
@ -66,6 +123,18 @@ class ExecutionOutput {
|
||||
}
|
||||
}
|
||||
|
||||
void AddAliasedIndex(ShapeIndex index) {
|
||||
aliased_indices_.push_back(std::move(index));
|
||||
}
|
||||
|
||||
void AddToBeReleased(se::OwningDeviceMemory mem) {
|
||||
to_be_released_.push_back(std::move(mem));
|
||||
}
|
||||
|
||||
void SetOutputShapeTable(se::OwningDeviceMemory output_shape_table) {
|
||||
output_shape_table_ = std::move(output_shape_table);
|
||||
}
|
||||
|
||||
// Should be called once it is known that the execute operation succeeded,
|
||||
// before returning the ExecutionOutput to the caller.
|
||||
ExecutionOutput& Commit() {
|
||||
@ -75,6 +144,8 @@ class ExecutionOutput {
|
||||
|
||||
const ScopedShapedBuffer& Result() const { return result_; }
|
||||
|
||||
ScopedShapedBuffer* MutableResult() { return &result_; }
|
||||
|
||||
const se::OwningDeviceMemory& ShapeTable() const {
|
||||
return output_shape_table_;
|
||||
}
|
||||
@ -169,12 +240,12 @@ class Executable {
|
||||
// complete.
|
||||
StatusOr<ExecutionOutput> ExecuteOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
||||
std::vector<ExecutionInput> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile);
|
||||
|
||||
virtual StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
||||
std::vector<ExecutionInput> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) = 0;
|
||||
|
||||
// Same as ExecuteOnStream(), but runs this executable on multiple
|
||||
@ -208,7 +279,7 @@ class Executable {
|
||||
|
||||
StatusOr<ExecutionOutput> ExecuteAsyncOnStreamWrapper(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments);
|
||||
std::vector<ExecutionInput> arguments);
|
||||
|
||||
const HloProfilePrinterData& hlo_profile_printer_data() const {
|
||||
CHECK(hlo_profiling_enabled());
|
||||
|
0
tensorflow/compiler/xla/service/gpu/BUILD
Executable file → Normal file
0
tensorflow/compiler/xla/service/gpu/BUILD
Executable file → Normal file
@ -328,7 +328,7 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) {
|
||||
|
||||
StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
|
||||
std::vector<ExecutionInput> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
XLA_SCOPED_LOGGING_TIMER(absl::StrCat("GpuExecutable::ExecuteAsyncOnStream(",
|
||||
module().name(), ")"));
|
||||
@ -367,7 +367,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
|
||||
auto param_no = allocation.parameter_number();
|
||||
se::DeviceMemoryBase buffer =
|
||||
arguments[param_no]
|
||||
.element(allocation.param_shape_index())
|
||||
.Buffer(allocation.param_shape_index())
|
||||
.AsDeviceMemoryBase();
|
||||
|
||||
// All top-level buffers and sub-buffers must have an explicit, non-null
|
||||
@ -458,16 +458,15 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
|
||||
TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result));
|
||||
|
||||
std::vector<se::OwningDeviceMemory> buffers_to_free;
|
||||
for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
|
||||
for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
|
||||
auto maybe_owning_buffer = buffer.second.Release();
|
||||
for (auto& argument : arguments) {
|
||||
for (auto& index_buffer : *argument.MutableBuffers()) {
|
||||
auto maybe_owning_buffer = index_buffer.second.Release();
|
||||
if (maybe_owning_buffer) {
|
||||
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
|
||||
}
|
||||
}
|
||||
}
|
||||
return ExecutionOutput(std::move(shaped_buffer), std::move(buffers_to_free),
|
||||
{}, {});
|
||||
return ExecutionOutput(std::move(shaped_buffer), std::move(buffers_to_free));
|
||||
}
|
||||
|
||||
const InstructionValueSet& GpuExecutable::GetRootValueSet() const {
|
||||
|
@ -84,7 +84,7 @@ class GpuExecutable : public Executable {
// doesn't match the compute capability passed to this object's constructor.
StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) override;

std::shared_ptr<const BufferAssignment> GetBufferAssignment() const {
@ -520,7 +520,7 @@ llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b)),
b->CreateICmpEQ(
b->getInt32(0),
EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b)));
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b)));
}

bool AreFusedReductionOutputsConsistent(
@ -2156,9 +2156,9 @@ void IrEmitterUnnested::EmitPrologueForReduction(
reduce_inst->shape().element_type(), module_);
llvm::Type* buffer_type = [&] {
if (reduction_info->IsRowReduction()) {
// Allocate __shared__ cache[num_partial_results][num_threads].
// Allocate __shared__ cache[num_partial_results][kWarpSize].
return llvm::ArrayType::get(
llvm::ArrayType::get(primitive_type, num_threads_x),
llvm::ArrayType::get(primitive_type, kWarpSize),
num_partial_results);
} else {
// Allocate __shared__
@ -815,6 +815,30 @@ ENTRY %primitive_computation_svd.38 (constant_5: f32[3,29,29], fusion.3: pred[3]
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
}

TEST_F(GpuKernelTilingTest, RowReductionCorrectShmemUsage) {
const char *const kHloString = R"(
HloModule RowReduce

Sum {
x.1 = f32[] parameter(0)
y.1 = f32[] parameter(1)
ROOT add.1 = f32[] add(x.1, y.1)
}

ENTRY reduce.1 {
parameter = f32[1048576] parameter(0)
init_value = f32[] constant(0)
ROOT reduce = f32[] reduce(parameter, init_value), dimensions={0}, to_apply=Sum
}
)";
auto hlo_module = ParseAndReturnVerifiedModule(kHloString).ValueOrDie();
auto expected_ir = R"(
; CHECK: shared_cache_{{[0-9]*}} = private addrspace({{[0-9]*}}) global [1 x [32 x float]]
)";
CompileAndVerifyIr(std::move(hlo_module), expected_ir,
/*match_optimized_ir=*/true);
}

} // namespace
} // namespace gpu
} // namespace xla
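
The test above pins down the effect of the IrEmitterUnnested change: with one partial result per thread, the row-reduction shared cache is now a fixed [1 x [32 x float]] no matter how wide the thread block is. A rough back-of-the-envelope sketch of the footprint (plain C++, not XLA code; num_threads_x = 1024 is only an assumed block width for contrast):

    #include <cstdio>

    int main() {
      constexpr int kWarpSize = 32;
      constexpr int num_partial_results = 1;  // as in the RowReduce test above
      constexpr int num_threads_x = 1024;     // assumed block width, for contrast
      std::printf("old cache: %d floats, new cache: %d floats\n",
                  num_partial_results * num_threads_x,  // scaled with block width
                  num_partial_results * kWarpSize);     // capped at one warp
      return 0;
    }
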
@ -412,6 +412,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
absl::flat_hash_map<const HloValue*, BufferInterval> buffer_intervals_;
Result result_;
BufferIntervalCompare buffer_interval_compare_;
BufferIntervalTree interval_tree_;

private:
int64 alignment_;
@ -420,8 +421,6 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
// Alloc or Free call.
int64 current_time_ = 0;

BufferIntervalTree interval_tree_;

// Returns all transitive colocated buffers of this buffer interval. I.e., If
// a buffer A is colocated with B and B is colocated with C, this function
// returns all three of them.
@ -17,6 +17,7 @@ limitations under the License.

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
@ -109,8 +110,8 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
TEST_F(HloConstantFoldingTest, Concatenate) {
const struct TestConfig {
int concat_dimension;
absl::Span<const int64> dimensions;
absl::Span<const int64> concat_sizes;
std::vector<int64> dimensions;
std::vector<int64> concat_sizes;
} test_configs[] = {
{1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}},
{3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}},
0
tensorflow/compiler/xla/service/hlo_instruction.cc
Executable file → Normal file
0
tensorflow/compiler/xla/service/hlo_instructions.h
Executable file → Normal file
@ -717,7 +717,10 @@ HloComputation* HloModule::GetComputationWithName(absl::string_view name) {

uint64 HloModule::Hash() const {
uint64 result = entry_computation_layout().Hash();
for (auto* computation : MakeComputationPostOrder()) {
// Use MakeComputationSortedByContent() instead of MakeComputationPostOrder()
// because naming may affect the order of MakeComputationPostOrder() but not
// MakeComputationSortedByContent().
for (auto* computation : MakeComputationSortedByContent()) {
for (auto* instruction : computation->MakeInstructionPostOrder()) {
result = tensorflow::Hash64Combine(result, instruction->Hash());
}
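
The comment in that hunk is the core of the change: folding element hashes together in traversal order only gives a stable module hash if the traversal order itself is stable, and the diff notes that naming may perturb MakeComputationPostOrder(). A minimal, self-contained illustration of why sorting by content first removes that sensitivity (toy types and a stand-in mixer, not the TensorFlow/XLA API):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <cstdio>
    #include <functional>
    #include <string>
    #include <vector>

    // Stand-in for tensorflow::Hash64Combine; only the order-sensitive mixing matters.
    uint64_t Hash64Combine(uint64_t a, uint64_t b) {
      return a ^ (b + 0x9e3779b97f4a7c15ULL + (a << 6) + (a >> 2));
    }

    // Combine element hashes, optionally sorting by content first so the result
    // no longer depends on the order in which the elements were visited.
    uint64_t CombinedHash(std::vector<std::string> contents, bool sort_by_content) {
      if (sort_by_content) std::sort(contents.begin(), contents.end());
      uint64_t result = 0;
      for (const std::string& c : contents) {
        result = Hash64Combine(result, std::hash<std::string>{}(c));
      }
      return result;
    }

    int main() {
      // The same "computations" visited in two different orders.
      std::vector<std::string> order_a = {"add", "mul", "reduce"};
      std::vector<std::string> order_b = {"reduce", "add", "mul"};
      // The content-sorted combination is identical regardless of visitation order.
      assert(CombinedHash(order_a, true) == CombinedHash(order_b, true));
      // The order-sensitive combination generally is not.
      std::printf("unsorted: %llu vs %llu\n",
                  static_cast<unsigned long long>(CombinedHash(order_a, false)),
                  static_cast<unsigned long long>(CombinedHash(order_b, false)));
      return 0;
    }
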
@ -196,6 +196,10 @@ class HloModule {
// computation B, then A will appear after B in the sort.
std::vector<HloComputation*> MakeComputationPostOrder() const;

// Same as MakeComputationPostOrder() but sorting the computations by their
// contents.
std::vector<HloComputation*> MakeComputationSortedByContent() const;

// Gets the computations in this module which aren't for fusion nodes.
//
// Postcondition: All computations in the returned list have
@ -346,10 +350,6 @@ class HloModule {
std::unique_ptr<HloComputation> computation, bool is_entry,
bool uniquify_identifiers);

// Same as MakeComputationPostOrder() but sorting the computations by their
// contents.
std::vector<HloComputation*> MakeComputationSortedByContent() const;

string name_;
HloModuleConfig config_;
HloComputation* entry_computation_ = nullptr;
@ -104,9 +104,8 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(

VLOG(1) << "Run backend " << hlo_module->name();

// Typically you would visit the HLO graph, building up a compiled equivalent
// In this case we are using an HloEvaluator at execution time, so we don't
// need to compile anything
TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
DynamicDimensionInference::Run(hlo_module.get()));

auto evaluator = absl::make_unique<HloEvaluator>();
evaluator->set_use_fast_path(
@ -115,8 +114,9 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(

// Create executable from only the Hlo module.
std::unique_ptr<Executable> executable =
absl::make_unique<InterpreterExecutable>(std::move(hlo_module),
std::move(evaluator));
absl::make_unique<InterpreterExecutable>(
std::move(hlo_module), std::move(evaluator),
std::move(dynamic_dimension_inference));

return std::move(executable);
}
@ -39,16 +39,23 @@ namespace interpreter {

InterpreterExecutable::InterpreterExecutable(
std::unique_ptr<HloModule> hlo_module,
std::unique_ptr<HloEvaluator> evaluator)
std::unique_ptr<HloEvaluator> evaluator,
absl::optional<DynamicDimensionInference> dynamic_dymension_inference)
: Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr,
/*hlo_profile_index_map=*/nullptr),
evaluator_(std::move(evaluator)) {}
evaluator_(std::move(evaluator)),
dynamic_dimension_inference_(std::move(dynamic_dymension_inference)) {
if (dynamic_dimension_inference_.has_value()) {
evaluator_->set_dynamic_dimension_inference(
&dynamic_dimension_inference_.value());
}
}

InterpreterExecutable::~InterpreterExecutable() {}

StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) {
se::Stream* stream = run_options->stream();
se::StreamExecutor* executor = stream->parent();
@ -58,13 +65,14 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
// TransferManager methods below.
std::vector<ShapedBuffer> argument_buffers;
argument_buffers.reserve(arguments.size());
for (const ShapeTree<MaybeOwningDeviceMemory>& arg : arguments) {
argument_buffers.push_back(ShapedBuffer(arg.shape(), arg.shape(),
for (auto& argument : arguments) {
const ShapeTree<MaybeOwningDeviceMemory>& buffers = argument.Buffers();
argument_buffers.push_back(ShapedBuffer(buffers.shape(), buffers.shape(),
/*platform=*/nullptr,
/*device_ordinal=*/0));
auto in_it = arg.begin();
auto in_it = buffers.begin();
auto out_it = argument_buffers.back().buffers().begin();
for (; in_it != arg.end(); ++in_it, ++out_it) {
for (; in_it != buffers.end(); ++in_it, ++out_it) {
out_it->second = in_it->second.AsDeviceMemoryBase();
}
}
@ -120,12 +128,13 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
}

// Transform the result literal back into a ShapedBuffer.
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result_buffers,
transfer_manager->AllocateScopedShapedBuffer(
result_literal.shape(), run_options->allocator(),
executor->device_ordinal()));
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
run_options->stream(), result_literal, result));
run_options->stream(), result_literal, result_buffers));
ExecutionOutput result(std::move(result_buffers));

uint64 end_micros = tensorflow::Env::Default()->NowMicros();

@ -134,17 +143,15 @@ StatusOr<ExecutionOutput> InterpreterExecutable::ExecuteAsyncOnStream(
const double nanoseconds = (end_micros - start_micros) * 1000.0;
profile->set_compute_time_ns(std::max(nanoseconds, 1.0));
}

std::vector<se::OwningDeviceMemory> buffers_to_free;
for (ShapeTree<MaybeOwningDeviceMemory>& argument : arguments) {
for (std::pair<ShapeIndex, MaybeOwningDeviceMemory>& buffer : argument) {
auto maybe_owning_buffer = buffer.second.Release();
for (auto& argument : arguments) {
for (auto& index_buffer : *argument.MutableBuffers()) {
auto maybe_owning_buffer = index_buffer.second.Release();
if (maybe_owning_buffer) {
buffers_to_free.push_back(std::move(*maybe_owning_buffer));
result.AddToBeReleased(std::move(*maybe_owning_buffer));
}
}
}
return ExecutionOutput(std::move(result), std::move(buffers_to_free), {}, {});
return std::move(result);
}

/*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) {
@ -42,13 +42,15 @@ namespace interpreter {
// buffer allocation. Refer to interpreter/README.md for more.
class InterpreterExecutable : public Executable {
public:
InterpreterExecutable(std::unique_ptr<HloModule> hlo_module,
std::unique_ptr<HloEvaluator> evaluator);
InterpreterExecutable(
std::unique_ptr<HloModule> hlo_module,
std::unique_ptr<HloEvaluator> evaluator,
absl::optional<DynamicDimensionInference> dynamic_dymension_inference);
~InterpreterExecutable() override;

StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ShapeTree<MaybeOwningDeviceMemory>> arguments,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) override
LOCKS_EXCLUDED(evaluator_lock_);

@ -60,6 +62,7 @@ class InterpreterExecutable : public Executable {
mutable tensorflow::mutex evaluator_lock_;

private:
absl::optional<DynamicDimensionInference> dynamic_dimension_inference_;
TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable);
};
@ -15,6 +15,8 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/memory_space_assignment.h"

#include "tensorflow/compiler/xla/debug_options_flags.h"

namespace xla {

namespace {
@ -380,6 +382,13 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
continue;
}

if (!ConsumeFuel("memory_space_assignment", [&] {
return absl::StrCat("Ran out of fuel at buffer: ",
colocated_intervals[0]->buffer->ToShortString());
})) {
continue;
}

const HloComputation* defining_computation =
colocated_intervals[0]->buffer->defining_instruction()->parent();
MemorySpaceAssignment::Allocation* aliased_allocation = nullptr;
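
The ConsumeFuel guard added in the hunk above reads as the usual compiler-fuel pattern: each successful call burns one unit of a global budget, and once the budget is exhausted the pass skips further buffers, which makes it easy to bisect which allocation introduces a problem. That reading, and the sketch below, are an interpretation; the toy ConsumeFuel here is a stand-in, not XLA's actual fuel API.

    #include <cstdio>
    #include <functional>
    #include <string>

    // Toy fuel counter: a real implementation would read the budget from a flag.
    static int fuel = 3;

    bool ConsumeFuel(const std::string& pass,
                     const std::function<std::string()>& ran_out_message) {
      if (fuel <= 0) {
        std::printf("[%s] %s\n", pass.c_str(), ran_out_message().c_str());
        return false;
      }
      --fuel;
      return true;
    }

    int main() {
      for (int i = 0; i < 5; ++i) {
        if (!ConsumeFuel("memory_space_assignment", [&] {
              return "Ran out of fuel at buffer: b" + std::to_string(i);
            })) {
          continue;  // out of fuel: leave this buffer in default memory
        }
        std::printf("considered buffer b%d for alternate memory\n", i);
      }
      return 0;
    }
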
@ -458,8 +467,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
// If the allocation finding failed (e.g., due to running out of
// asynchronous copies), then fall back to allocating the buffer
// entirely in the default memory.
pending_chunks_.clear();
pending_async_copies_.clear();
UncommitPendingChunks();
allocation_sequence->clear();
break;
}
@ -478,7 +486,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
}
}

CommitPendingChunks();
pending_chunks_.clear();
pending_async_copies_.clear();
}

if (VLOG_IS_ON(3)) {
@ -510,6 +519,12 @@ void AsynchronousCopyOrdering::AddCopy(const AsynchronousCopy& copy) {
it_and_inserted.first->start_time == copy.start_time);
}

void AsynchronousCopyOrdering::RemoveCopy(const AsynchronousCopy& copy) {
auto copy_it = ranges_.find(copy);
CHECK(copy_it != ranges_.end());
ranges_.erase(copy_it);
}

bool AsynchronousCopyOrdering::ViolatesOrdering(int64 start_time,
int64 end_time) const {
// We allow identical start and end times. It is enough to check for just the
@ -620,32 +635,31 @@ bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory(
return false;
}

void AlternateMemoryBestFitHeap::CommitPendingChunks() {
void AlternateMemoryBestFitHeap::UncommitPendingChunks() {
for (auto interval_and_chunk : pending_chunks_) {
VLOG(3) << "Committing chunk: " << interval_and_chunk.first.start << "-"
<< interval_and_chunk.first.end << " : ["
<< interval_and_chunk.second.chunk.offset << ", "
<< interval_and_chunk.second.chunk.size << "]";
CommitChunk(interval_and_chunk.first, interval_and_chunk.second);
const BufferInterval& interval = interval_and_chunk.first;
const Chunk& chunk = interval_and_chunk.second.chunk;
interval_tree_.Remove(interval.start, interval.end, chunk);
}
for (const auto& interval : pending_async_copies_) {
async_copy_interval_tree_.Remove(interval.start_time, interval.end_time,
kDummyChunk);
if (interval.destination == MemorySpace::kAlternate) {
async_copy_ordering_.RemoveCopy(interval);
}
}
pending_chunks_.clear();
// Also add the pending async copies to the interval tree.
for (const auto& interval : pending_async_copies_) {
if (options_.max_outstanding_async_copies >= 0) {
async_copy_interval_tree_.Add(interval.start_time, interval.end_time,
kDummyChunk);
}
if (interval.destination == MemorySpace::kAlternate) {
async_copy_ordering_.AddCopy(interval);
}
}
pending_async_copies_.clear();
}

void AlternateMemoryBestFitHeap::AddToPendingChunks(
const BufferInterval& buffer_interval,
const ChunkCandidate& chunk_candidate) {
VLOG(3) << "Committing chunk: " << buffer_interval.start << "-"
<< buffer_interval.end << " : [" << chunk_candidate.chunk.offset
<< ", " << chunk_candidate.chunk.size << "]";
pending_chunks_.emplace_back(buffer_interval, chunk_candidate);
CommitChunk(buffer_interval, chunk_candidate);
}

bool AlternateMemoryBestFitHeap::RequiredInDefaultMemory(const HloValue* buffer,
@ -772,6 +786,10 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy(
// Register the additional async copy with the interval tree to keep track of
// the limit at any given time.
pending_async_copies_.push_back({start_time, end_time, memory_space});
async_copy_interval_tree_.Add(start_time, end_time, kDummyChunk);
if (memory_space == MemorySpaceAssignment::MemorySpace::kAlternate) {
async_copy_ordering_.AddCopy(pending_async_copies_.back());
}
}

bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
@ -780,17 +798,11 @@ bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
return false;
}

// Count both the asynchronous copies in the interval tree as well as the
// pending asynchronous copies belonging to this buffer.
// Count the asynchronous copies in the interval tree for the given interval.
int64 num_async_copies =
async_copy_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
.size();

for (const auto& interval : pending_async_copies_) {
if (interval.start_time > start_time && interval.end_time < end_time) {
num_async_copies++;
}
}
// Add one because we are checking if adding an additional asynchronous copy
// would violate the limit.
return num_async_copies + 1 > options_.max_outstanding_async_copies;
@ -798,19 +810,7 @@ bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(

bool AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(
int64 start_time, int64 end_time) const {
if (async_copy_ordering_.ViolatesOrdering(start_time, end_time)) {
return true;
}

// Also check pending async copies.
for (const auto& async_copy : pending_async_copies_) {
if (async_copy.destination == MemorySpace::kAlternate &&
async_copy.start_time <= end_time &&
start_time <= async_copy.end_time) {
return true;
}
}
return false;
return async_copy_ordering_.ViolatesOrdering(start_time, end_time);
}

bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy(
@ -1048,7 +1048,7 @@ bool AlternateMemoryBestFitHeap::Prefetch(

ChunkCandidate chunk_candidate = FindChunkCandidate(alternate_mem_interval);
// Check if the new heap size fits within limits.
if (chunk_candidate.heap_size < available_heap_size()) {
if (chunk_candidate.heap_size <= available_heap_size()) {
VLOG(3) << "Move the buffer to alternate memory at "
<< alternate_mem_interval.start
<< ". Offset = " << chunk_candidate.chunk.offset
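
Since pending copies are now inserted into the interval trees as soon as AddAsyncCopy records them, the two checks above no longer need a separate loop over pending_async_copies_. A self-contained sketch of the simplified counting logic (toy types, not XLA's interval tree):

    #include <vector>

    struct AsyncCopy {
      int start_time;
      int end_time;
    };

    // Returns true if adding one more copy over [start_time, end_time] would
    // exceed the limit; a negative limit means "unlimited".
    bool ViolatesLimit(const std::vector<AsyncCopy>& outstanding, int start_time,
                       int end_time, int max_outstanding_async_copies) {
      if (max_outstanding_async_copies < 0) return false;
      int num_async_copies = 0;
      for (const AsyncCopy& copy : outstanding) {
        if (copy.start_time <= end_time && start_time <= copy.end_time) {
          ++num_async_copies;  // already overlaps the candidate interval
        }
      }
      // Add one because we are asking whether an *additional* copy would
      // violate the limit.
      return num_async_copies + 1 > max_outstanding_async_copies;
    }

    int main() {
      std::vector<AsyncCopy> outstanding = {{0, 5}, {3, 8}};
      // A new copy over [4, 6] overlaps both outstanding copies, so with a
      // limit of 2 it would be rejected.
      return ViolatesLimit(outstanding, 4, 6, 2) ? 0 : 1;
    }
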
@ -580,6 +580,9 @@ class AsynchronousCopyOrdering {
// Adds an asynchronous copy.
void AddCopy(const AsynchronousCopy& copy);

// Removes an asynchronous copy. CHECKs that it is removed.
void RemoveCopy(const AsynchronousCopy& copy);

// Returns true if the addition of an asynchronous copy in the the given time
// interval would violate the asynchronous copy ordering. E.g., consider the
// following scenario:
@ -735,11 +738,15 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
int64 end_time, int64 copy_done_schedule_before_time,
MemorySpaceAssignment::AllocationSequence* allocations);

// These methods are used for delaying committing the chunk candidate until
// the entire live range of the buffer has been considered.
// This method is used for committing the chunk candidate but adding it to
// pending_chunks_ so that we can "uncommit" them in case we need to roll back
// this allocation sequence.
void AddToPendingChunks(const BufferInterval& buffer_interval,
const ChunkCandidate& chunk_candidate);
void CommitPendingChunks();
// If we need to remove the allocations for this allocation sequence, this
// removes pending chunks and asynchronous copies in the respective pending
// buffers from the interval trees.
void UncommitPendingChunks();

// Returns the available heap size in the alternate memory.
int64 available_heap_size() const {
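
The new comments describe a commit-then-roll-back discipline: each chunk is committed to the interval tree immediately (so later candidates see it as occupied) while also being remembered in pending_chunks_, and a failed allocation sequence is undone by removing exactly those pending entries. A self-contained sketch of that pattern with toy types (none of these are the XLA classes):

    #include <cstddef>
    #include <set>
    #include <tuple>
    #include <vector>

    struct Chunk {
      int start, end, offset;
      bool operator<(const Chunk& other) const {
        return std::tie(start, end, offset) <
               std::tie(other.start, other.end, other.offset);
      }
    };

    class ToyHeap {
     public:
      // Commit eagerly so later candidates see the chunk as occupied, and
      // remember it so the commit can be rolled back.
      void AddToPendingChunks(const Chunk& chunk) {
        pending_chunks_.push_back(chunk);
        committed_.insert(chunk);
      }

      // Roll back a failed allocation sequence: remove exactly the pending chunks.
      void UncommitPendingChunks() {
        for (const Chunk& chunk : pending_chunks_) {
          committed_.erase(committed_.find(chunk));
        }
        pending_chunks_.clear();
      }

      // The allocation sequence succeeded: keep the commits, drop the bookkeeping.
      void ClearPending() { pending_chunks_.clear(); }

      std::size_t num_committed() const { return committed_.size(); }

     private:
      std::vector<Chunk> pending_chunks_;
      std::multiset<Chunk> committed_;
    };

    int main() {
      ToyHeap heap;
      heap.AddToPendingChunks({0, 10, 0});
      heap.AddToPendingChunks({5, 15, 64});
      heap.UncommitPendingChunks();  // the sequence failed: both commits are undone
      return static_cast<int>(heap.num_committed());  // 0
    }
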
@ -2827,6 +2827,100 @@ TEST_P(MemorySpaceAssignmentTest,
}
}

TEST_P(MemorySpaceAssignmentTest, PendingChunkMemoryCorruptionBug) {
// Tests a memory corruption bug where the allocated chunk overlaps with a
// pending chunk. To test this, we provide a new buffer interval compare where
// we prioritize the allocation of sine, cosine, and tanh to create the
// situation:
//
// Max memory
// -------------------------------------------
// +------------+
// | b |
// +------------+
// +-------+
// | |
// | |
// | a |
// | | +------------+
// | | | n |
// +-------+ +------------+
// -------------------------------------------
// Min memory time ->
//
//
// Then allocating for buffer d, we have these two prefetch buffers
// overlapping:
//
// Max memory
// -------------------------------------------
// +------------+ +----------+
// | b | | prefetch |
// +------------+ | for o |
// +-------+ +---------+ |
// | | | | | |
// | | | | | |
// | a | | +----|-----+
// | | | prefetch| +------------+
// | | | for m | | n |
// +-------+ +---------+ +------------+
// -------------------------------------------
// Min memory time ->
//
absl::string_view hlo_string = R"(
HloModule bug, is_scheduled=true

ENTRY %Entry {
%param0 = f32[8,3] parameter(0)
%param1 = f32[2,4] parameter(1)
%a = f32[8,3] sine(%param0)
%b = f32[2,4] cosine(%param1)
%d = f32[8,3] tanh(%a)
%c = f32[8,3] negate(%a)
%e = f32[2,4] negate(%b)
%f = f32[2,4] negate(%e)
%g = f32[2,4] negate(%f)
%h = f32[2,4] negate(%g)
%i = f32[2,4] negate(%h)
%j = f32[2,4] negate(%i)
%k = f32[2,4] negate(%j)
%l = f32[2,4] negate(%k)
%m = f32[8,3] negate(%d)
%n = f32[2,4] sine(%l)
%o = f32[8,3] negate(%d)
%p = f32[2,4] negate(%n)
%q = f32[8,3] negate(%m)
ROOT %tuple = (f32[2,4], f32[8,3], f32[8,3]) tuple(%p, %q, %o)
}
)";

MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
[](const MemorySpaceAssignment::BufferInterval& a,
const MemorySpaceAssignment::BufferInterval& b) {
auto get_opcode_priority = [](const HloOpcode& opcode) {
switch (opcode) {
case HloOpcode::kSin:
return 0;
case HloOpcode::kCos:
return 1;
case HloOpcode::kTanh:
return 2;
default:
return 3;
}
};

return get_opcode_priority(a.buffer->defining_instruction()->opcode()) <
get_opcode_priority(b.buffer->defining_instruction()->opcode());
};
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));

InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
buffer_interval_compare, &prefetch_interval_picker);
}

TEST_P(MemorySpaceAssignmentTest, Determinism) {
// Run memory space assignment a few times to make sure every time it compiles
// to the same thing.
@ -2513,6 +2513,18 @@ xla_test(
],
)

xla_test(
name = "get_dimension_size_test",
srcs = ["get_dimension_size_test.cc"],
deps = [
":hlo_test_base",
":test_macros_header",
":xla_internal_test_main", # fixdeps: keep
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:test",
],
)

xla_test(
name = "triangular_solve_test",
srcs = ["triangular_solve_test.cc"],
@ -96,8 +96,8 @@ class BufferDonationTest : public HloTestBase {
memory_allocator.get());
});

std::vector<ShapeTree<MaybeOwningDeviceMemory>> args;
args.emplace_back(std::move(owned_buffers));
std::vector<ExecutionInput> args;
args.emplace_back(ExecutionInput(std::move(owned_buffers)));

TF_ASSERT_OK_AND_ASSIGN(
ExecutionOutput output,
48
tensorflow/compiler/xla/tests/get_dimension_size_test.cc
Normal file
@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace {

class GetDimensionSizeTest : public HloTestBase {};

// Test that the interpreter can correctly compute get_dimension_size.
TEST_F(GetDimensionSizeTest, DoIt) {
const char* const kModuleStr = R"(
HloModule a_inference_call_110__.55

ENTRY %a_inference_call_110__.55 (arg0.1: f32[1,8], arg1.2: f32[8], arg2.3: f32[8]) -> s32[] {
%constant.37 = f32[] constant(1e-12)
%broadcast.38 = f32[1,1]{1,0} broadcast(f32[] %constant.37), dimensions={}
%arg0.1 = f32[1,8]{1,0} parameter(0), parameter_replication={false}
%reshape.4 = f32[1,8]{1,0} reshape(f32[1,8]{1,0} %arg0.1)
%convert.5 = f32[1,8]{1,0} convert(f32[1,8]{1,0} %reshape.4)
%constant.6 = f32[] constant(0)
%convert.7 = f32[] convert(f32[] %constant.6)
ROOT %get-dimension-size.13 = s32[] get-dimension-size(f32[1,8]{1,0} %convert.5), dimensions={1}
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr));
EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0.01, 0.01}));
}

} // anonymous namespace
} // namespace xla
@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

#include <memory>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
@ -39,8 +40,8 @@ static std::array<bool, 1> use_bfloat16_params{false};
#endif

struct ReverseSpec {
absl::Span<const int64> input_dims;
absl::Span<const int64> reversal;
std::vector<int64> input_dims;
std::vector<int64> reversal;
bool use_bfloat16;

string ToTestCaseName() const {
@ -72,6 +72,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_proto_cc",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
Some files were not shown because too many files have changed in this diff