Merge branch 'master' into master

commit fccd345ab3
@@ -24,9 +24,21 @@ cc_library(
         "//tensorflow:windows": get_win_copts(),
     }),
     deps = [
+        ":gcs_helper",
+        "//tensorflow/c:env",
         "//tensorflow/c:tf_status",
         "//tensorflow/c/experimental/filesystem:filesystem_interface",
         "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
         "@com_google_absl//absl/strings",
     ],
 )
+
+cc_library(
+    name = "gcs_helper",
+    srcs = ["gcs_helper.cc"],
+    hdrs = ["gcs_helper.h"],
+    linkstatic = 1,
+    deps = [
+        "//tensorflow/c:env",
+    ],
+)
@@ -15,9 +15,13 @@ limitations under the License.
 #include <stdlib.h>
 #include <string.h>
 
+#include <fstream>
+
 #include "absl/strings/string_view.h"
 #include "google/cloud/storage/client.h"
+#include "tensorflow/c/env.h"
 #include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
+#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
 #include "tensorflow/c/tf_status.h"
 
 // Implementation of a filesystem for GCS environments.
@@ -86,6 +90,20 @@ namespace tf_random_access_file {
 // SECTION 2. Implementation for `TF_WritableFile`
 // ----------------------------------------------------------------------------
 namespace tf_writable_file {
+typedef struct GCSFile {
+  const char* bucket;
+  const char* object;
+  gcs::Client* gcs_client;  // not owned
+  TempFile outfile;
+  bool sync_need;
+} GCSFile;
+
+static void Cleanup(TF_WritableFile* file) {
+  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
+  plugin_memory_free(const_cast<char*>(gcs_file->bucket));
+  plugin_memory_free(const_cast<char*>(gcs_file->object));
+  delete gcs_file;
+}
 
 // TODO(vnvo2409): Implement later
 
@@ -119,6 +137,20 @@ static void Init(TF_Filesystem* filesystem, TF_Status* status) {
 
 // TODO(vnvo2409): Implement later
 
+static void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
+                            TF_WritableFile* file, TF_Status* status) {
+  char* bucket;
+  char* object;
+  ParseGCSPath(path, false, &bucket, &object, status);
+  if (TF_GetCode(status) != TF_OK) return;
+
+  auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
+  TempFile outfile(TF_GetTempFileName(""), std::ios::binary | std::ios::out);
+  file->plugin_file = new tf_writable_file::GCSFile(
+      {bucket, object, gcs_client, std::move(outfile), true});
+  TF_SetStatus(status, TF_OK, "");
+}
+
 }  // namespace tf_gcs_filesystem
 
 static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
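The writable-file support added above is deliberately partial: NewWritableFile only parses the GCS path and stages output in a local TempFile, while the remaining callbacks are still marked "Implement later". A minimal sketch of how the staged-write flow could be completed, assuming the design implied by GCSFile; the Append/Close names follow the filesystem-plugin callback convention, and the use of gcs::Client::UploadFile is an assumption, not part of this commit:

// Hypothetical sketch only -- the real callbacks are still "Implement later".
// These would live next to Cleanup() inside namespace tf_writable_file.
static void Append(const TF_WritableFile* file, const char* buffer, size_t n,
                   TF_Status* status) {
  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
  // Stage the bytes in the local temporary file; nothing is sent to GCS yet.
  if (!gcs_file->outfile.write(buffer, n))
    TF_SetStatus(status, TF_INTERNAL, "could not write to the temporary file");
  else
    TF_SetStatus(status, TF_OK, "");
}

static void Close(const TF_WritableFile* file, TF_Status* status) {
  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
  // Flush the staged bytes and upload the temporary file as one GCS object.
  gcs_file->outfile.flush();
  auto metadata = gcs_file->gcs_client->UploadFile(
      gcs_file->outfile.getName(), gcs_file->bucket, gcs_file->object);
  if (!metadata)
    TF_SetStatus(status, TF_INTERNAL, metadata.status().message().c_str());
  else
    TF_SetStatus(status, TF_OK, "");
}

Keeping all writes in the temporary file and uploading once on Close avoids issuing one GCS request per Append.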
@@ -126,9 +158,14 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
   TF_SetFilesystemVersionMetadata(ops);
   ops->scheme = strdup(uri);
 
+  ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
+      plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
+  ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
+
   ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
       plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
   ops->filesystem_ops->init = tf_gcs_filesystem::Init;
+  ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
 }
 
 void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
@@ -0,0 +1,34 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
+
+#include <stdio.h>
+
+#include <fstream>
+#include <string>
+#include <utility>
+
+TempFile::TempFile(const char* temp_file_name, std::ios::openmode mode)
+    : std::fstream(temp_file_name, mode), name_(temp_file_name) {}
+
+TempFile::TempFile(TempFile&& rhs)
+    : std::fstream(std::move(rhs)), name_(std::move(rhs.name_)) {}
+
+TempFile::~TempFile() {
+  std::fstream::close();
+  std::remove(name_.c_str());
+}
+
+const std::string TempFile::getName() const { return name_; }
@@ -0,0 +1,33 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
+#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
+
+#include <fstream>
+#include <string>
+
+class TempFile : public std::fstream {
+ public:
+  // We should specify openmode each time we call TempFile.
+  TempFile(const char* temp_file_name, std::ios::openmode mode);
+  TempFile(TempFile&& rhs);
+  ~TempFile() override;
+  const std::string getName() const;
+
+ private:
+  const std::string name_;
+};
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
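TempFile, as declared above, is just a std::fstream that remembers its path and removes the file in its destructor; the move constructor keeps the name alive when a GCSFile takes ownership. A small standalone usage sketch, where std::tmpnam stands in for TF_GetTempFileName purely for illustration:

#include <cstdio>
#include <iostream>

#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"

int main() {
  // Create a scratch file; the name is retained so it can be removed later.
  TempFile staging(std::tmpnam(nullptr), std::ios::binary | std::ios::out);
  staging << "staged bytes that would later be uploaded to GCS";
  staging.flush();
  std::cout << "staging file: " << staging.getName() << "\n";
  // When `staging` goes out of scope, the destructor closes the stream and
  // deletes the file from local disk.
  return 0;
}

The move constructor matters because GCSFile stores the TempFile by value; moving preserves name_ so cleanup still happens exactly once.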
@@ -314,7 +314,6 @@ tf_cc_test(
 cc_library(
     name = "tensorflow_lite_legalize_tf",
     srcs = [
-        "transforms/device_index_selector.cc",
         "transforms/dilated_conv.cc",
         "transforms/generated_legalize_tf.inc",
         "transforms/generated_lower_static_tensor_list.inc",
@@ -953,14 +953,14 @@ in the batch dimensions and broadcasting.
   }];
 
   let arguments = (ins
-    TFL_TensorOf<[F32]>:$x,
-    TFL_TensorOf<[F32]>:$y,
+    TFL_TensorOf<[F32, QI8]>:$x,
+    TFL_TensorOf<[F32, QI8]>:$y,
     DefaultValuedAttr<BoolAttr, "false">:$adj_x,
     DefaultValuedAttr<BoolAttr, "false">:$adj_y
   );
 
   let results = (outs
-    TFL_TensorOf<[F32]>:$output
+    TFL_TensorOf<[F32, QI8]>:$output
   );
 
   let hasOptions = 1;
@@ -63,7 +63,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
   standard_pipeline_options.enable_inliner = false;
   standard_pipeline_options.form_clusters = pass_config.form_clusters;
   mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
-  pass_manager->addPass(mlir::TFL::CreateDeviceIndexSelectorPass());
+  pass_manager->addPass(mlir::TF::CreateDeviceIndexSelectorPass());
 
   if (pass_config.shape_inference) {
     pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
@@ -91,9 +91,6 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
 // Verifies runtime constraints.
 std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
 
-// Creates function pass to select device index/fold tf.DeviceIndex.
-std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
-
 } // namespace TFL
 
 } // namespace mlir
@@ -475,6 +475,7 @@ cc_library(
         "transforms/cluster_outlining.cc",
         "transforms/collection_ops_util.cc",
         "transforms/decompose_resource_ops_pass.cc",
+        "transforms/device_index_selector.cc",
        "transforms/einsum.cc",
        "transforms/executor_island_coarsening.cc",
        "transforms/executor_tpuv1_inline_tpu_island.cc",
@@ -164,6 +164,81 @@ def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastable
   let hasFolder = 1;
 }
 
+def TF_AdjustContrastv2Op : TF_Op<"AdjustContrastv2", [NoSideEffect]> {
+  let summary = "Adjust the contrast of one or more images.";
+
+  let description = [{
+`images` is a tensor of at least 3 dimensions. The last 3 dimensions are
+interpreted as `[height, width, channels]`. The other dimensions only
+represent a collection of images, such as `[batch, height, width, channels].`
+
+Contrast is adjusted independently for each channel of each image.
+
+For each channel, the Op first computes the mean of the image pixels in the
+channel and then adjusts each component of each pixel to
+`(x - mean) * contrast_factor + mean`.
+  }];
+
+  let arguments = (ins
+    TensorOf<[F16, F32]>:$images,
+    F32Tensor:$contrast_factor
+  );
+
+  let results = (outs
+    TensorOf<[F16, F32]>:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
+def TF_AdjustHueOp : TF_Op<"AdjustHue", [NoSideEffect]> {
+  let summary = "Adjust the hue of one or more images.";
+
+  let description = [{
+`images` is a tensor of at least 3 dimensions. The last dimension is
+interpreted as channels, and must be three.
+
+The input image is considered in the RGB colorspace. Conceptually, the RGB
+colors are first mapped into HSV. A delta is then applied all the hue values,
+and then remapped back to RGB colorspace.
+  }];
+
+  let arguments = (ins
+    TensorOf<[F16, F32]>:$images,
+    F32Tensor:$delta
+  );
+
+  let results = (outs
+    TensorOf<[F16, F32]>:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
+def TF_AdjustSaturationOp : TF_Op<"AdjustSaturation", [NoSideEffect]> {
+  let summary = "Adjust the saturation of one or more images.";
+
+  let description = [{
+`images` is a tensor of at least 3 dimensions. The last dimension is
+interpreted as channels, and must be three.
+
+The input image is considered in the RGB colorspace. Conceptually, the RGB
+colors are first mapped into HSV. A scale is then applied all the saturation
+values, and then remapped back to RGB colorspace.
+  }];
+
+  let arguments = (ins
+    TensorOf<[F16, F32]>:$images,
+    F32Tensor:$scale
+  );
+
+  let results = (outs
+    TensorOf<[F16, F32]>:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_AllOp : TF_Op<"All", [NoSideEffect]> {
   let summary = [{
 Computes the "logical and" of elements across dimensions of a tensor.
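The AdjustContrastv2 description above reduces to a single per-channel formula: each component x becomes (x - mean) * contrast_factor + mean, with mean taken over that channel of that image. A tiny standalone check of the arithmetic (plain C++, not TensorFlow code):

#include <array>
#include <cassert>
#include <numeric>

int main() {
  // One 2x2 single-channel image and a contrast factor of 2.
  std::array<float, 4> channel{0.2f, 0.4f, 0.6f, 0.8f};
  const float contrast_factor = 2.0f;
  const float mean =
      std::accumulate(channel.begin(), channel.end(), 0.0f) / channel.size();  // 0.5
  for (float& x : channel) x = (x - mean) * contrast_factor + mean;
  // Values spread away from the mean: {-0.1, 0.3, 0.7, 1.1}.
  assert(channel[0] < 0.0f && channel[3] > 1.0f);
  return 0;
}

With contrast_factor = 2 the values move away from the channel mean and can leave [0, 1]; the formula itself does not clamp the result.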
@@ -3866,6 +3941,28 @@ tf.math.greater_equal(x, y) ==> [True, False, True, True]
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_HSVToRGBOp : TF_Op<"HSVToRGB", [NoSideEffect]> {
+  let summary = "Convert one or more images from HSV to RGB.";
+
+  let description = [{
+Outputs a tensor of the same shape as the `images` tensor, containing the RGB
+value of the pixels. The output is only well defined if the value in `images`
+are in `[0,1]`.
+
+See `rgb_to_hsv` for a description of the HSV encoding.
+  }];
+
+  let arguments = (ins
+    TF_FpTensor:$images
+  );
+
+  let results = (outs
+    TF_FpTensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_HashTableV2Op : TF_Op<"HashTableV2", []> {
   let summary = "Creates a non-initialized hash table.";
 
@@ -5962,11 +6059,11 @@ I.e., \\(y = -x\\).
   }];
 
   let arguments = (ins
-    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x
   );
 
   let results = (outs
-    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y
   );
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
@@ -6733,6 +6830,41 @@ the dimension is padded with zeros.
   TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>;
 }
 
+def TF_RGBToHSVOp : TF_Op<"RGBToHSV", [NoSideEffect]> {
+  let summary = "Converts one or more images from RGB to HSV.";
+
+  let description = [{
+Outputs a tensor of the same shape as the `images` tensor, containing the HSV
+value of the pixels. The output is only well defined if the value in `images`
+are in `[0,1]`.
+
+`output[..., 0]` contains hue, `output[..., 1]` contains saturation, and
+`output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0
+corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue.
+
+Usage Example:
+
+>>> blue_image = tf.stack([
+...    tf.zeros([5,5]),
+...    tf.zeros([5,5]),
+...    tf.ones([5,5])],
+...    axis=-1)
+>>> blue_hsv_image = tf.image.rgb_to_hsv(blue_image)
+>>> blue_hsv_image[0,0].numpy()
+array([0.6666667, 1. , 1. ], dtype=float32)
+  }];
+
+  let arguments = (ins
+    TF_FpTensor:$images
+  );
+
+  let results = (outs
+    TF_FpTensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_RandomGammaGradOp : TF_Op<"RandomGammaGrad", [NoSideEffect, ResultsBroadcastableShape]>,
                            WithBroadcastableBinOpBuilder {
   let summary = [{
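The doctest in the RGBToHSV description reports a hue of about 0.6666667 for pure blue. That matches the standard RGB-to-HSV mapping, where, when blue is the maximal channel, hue is ((r - g) / (max - min) + 4) / 6. A small C++ check of that conversion (illustrative only, not the TensorFlow kernel):

#include <algorithm>
#include <cassert>
#include <cmath>

// Minimal RGB -> HSV for illustration; all values in [0, 1].
void RgbToHsv(float r, float g, float b, float* h, float* s, float* v) {
  const float max = std::max({r, g, b});
  const float min = std::min({r, g, b});
  const float delta = max - min;
  *v = max;
  *s = max > 0.0f ? delta / max : 0.0f;
  if (delta == 0.0f) {
    *h = 0.0f;
  } else if (max == r) {
    *h = std::fmod((g - b) / delta + 6.0f, 6.0f) / 6.0f;
  } else if (max == g) {
    *h = ((b - r) / delta + 2.0f) / 6.0f;
  } else {
    *h = ((r - g) / delta + 4.0f) / 6.0f;
  }
}

int main() {
  float h, s, v;
  RgbToHsv(0.0f, 0.0f, 1.0f, &h, &s, &v);  // pure blue
  assert(std::fabs(h - 2.0f / 3.0f) < 1e-6f && s == 1.0f && v == 1.0f);
  return 0;
}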
@@ -7230,6 +7362,27 @@ Input images can be of different types but output images are always float.
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_ResizeBilinearGradOp : TF_Op<"ResizeBilinearGrad", [NoSideEffect]> {
+  let summary = "Computes the gradient of bilinear interpolation.";
+
+  let description = [{
+  }];
+
+  let arguments = (ins
+    F32Tensor:$grads,
+    TF_FpTensor:$original_image,
+
+    DefaultValuedAttr<BoolAttr, "false">:$align_corners,
+    DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
+  );
+
+  let results = (outs
+    TF_FpTensor:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
+}
+
 def TF_ResizeNearestNeighborOp : TF_Op<"ResizeNearestNeighbor", [NoSideEffect]> {
   let summary = [{
 Resize `images` to `size` using nearest neighbor interpolation.
@@ -21,11 +21,11 @@ limitations under the License.
 #include "mlir/IR/Operation.h"  // from @llvm-project
 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
-#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
 
 namespace mlir {
-namespace TFL {
+namespace TF {
 namespace {
 
 // Folds the DeviceIndex op to a constant value. The DeviceIndex return the
@@ -55,8 +55,8 @@ void DeviceIndexSelector::runOnOperation() {
   // Convert all the DeviceIndex ops to constant values.
   func.getBody().walk([](TF::DeviceIndexOp op) {
     // This just selects the default in all cases where DeviceIndex feeds into
-    // tf.Case. This could be enhanced based on explicit TFLite specification or
-    // TAC in future.
+    // tf.Case. This could be enhanced to have some sort of policy in the
+    // future.
     OpBuilder b(op);
     RankedTensorType type = RankedTensorType::get({}, b.getIntegerType(32));
     int index = op.device_names().size();
@@ -79,7 +79,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass() {
 }
 
 static PassRegistration<DeviceIndexSelector> pass(
-    "tfl-device-index-selector", "Fold tf.DeviceIndex to constant");
+    "tf-device-index-selector", "Fold tf.DeviceIndex to constant");
 
-} // namespace TFL
+} // namespace TF
 } // namespace mlir
@@ -52,11 +52,6 @@ struct FusedKernelMatcherPass
   void runOnFunction() override;
 };
 
-// Returns an op's name with the dialect prefix stripped off.
-StringRef GetOpNameWithoutDialect(Operation *op) {
-  return op->getName().getStringRef().split(".").second;
-}
-
 bool IsActivationFunction(Operation *op) {
   return isa<EluOp>(op) || isa<ReluOp>(op) || isa<Relu6Op>(op);
 }
@@ -128,8 +123,8 @@ class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
     }
 
     SmallVector<Location, 3> locations{contraction.getLoc(), bias_add.getLoc()};
-    SmallVector<Attribute, 2> fused_ops{
-        StringAttr::get(GetOpNameWithoutDialect(bias_add), context)};
+    SmallVector<Attribute, 2> fused_ops{StringAttr::get(
+        bias_add.getOperation()->getName().stripDialect(), context)};
 
     // BiasAdd may or may not feed into an activation function.
     auto activation = GetActivation(bias_add);
@@ -143,7 +138,7 @@ class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
     if (fuse_activation) {
       locations.push_back(activation->getLoc());
       fused_ops.push_back(
-          StringAttr::get(GetOpNameWithoutDialect(activation), context));
+          StringAttr::get(activation->getName().stripDialect(), context));
       result_type = activation->getResultTypes().front();
     } else {
       result_type = bias_add.getResult().getType();
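The change above drops the local GetOpNameWithoutDialect helper in favor of MLIR's OperationName::stripDialect(), which yields the op name with the dialect prefix removed. The behavior the fused_ops attribute relies on amounts to the following (the op name is only an example):

#include <cassert>

#include "llvm/ADT/StringRef.h"

int main() {
  // What GetOpNameWithoutDialect used to do, and what
  // OperationName::stripDialect() now does for the pass:
  llvm::StringRef full_name = "tf.BiasAdd";     // example op name
  llvm::StringRef stripped = full_name.split('.').second;
  assert(stripped == "BiasAdd");                // the fused_ops attribute stores this
  return 0;
}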
@@ -147,6 +147,9 @@ std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
 // generally used beyond exporting to runtimes that supports these ops. In the
 // future these fusions may be codegen'd automatically.
 std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass();
+
+// Creates function pass to select device index/fold tf.DeviceIndex.
+std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
 } // namespace TF
 
 namespace tf_executor {
@@ -106,53 +106,6 @@ DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices,
   return GetI64ElementsAttr(slice_limits, builder);
 }
 
-// Returns the padding value of the given position. If padding_attr is a
-// nullptr, returns 0.
-static int64_t GetPaddingValue(DenseIntElementsAttr padding_attr,
-                               ArrayRef<uint64_t> index) {
-  if (!padding_attr) return 0;
-  return padding_attr.getValue<int64_t>(index);
-}
-
-static bool IsOnlyPaddingSpatialDims(Value lhs,
-                                     ConvDimensionNumbers dimension_numbers,
-                                     DenseIntElementsAttr edge_padding_low,
-                                     DenseIntElementsAttr edge_padding_high) {
-  const int64_t batch_dim = dimension_numbers.input_batch_dimension().getInt();
-  const int64_t feature_dim =
-      dimension_numbers.input_feature_dimension().getInt();
-  if (edge_padding_low.getValue<int64_t>(batch_dim) ||
-      edge_padding_high.getValue<int64_t>(batch_dim))
-    return false;
-  if (edge_padding_low.getValue<int64_t>(feature_dim) ||
-      edge_padding_high.getValue<int64_t>(feature_dim))
-    return false;
-  return true;
-}
-
-DenseIntElementsAttr BuildConvPaddingAttrs(
-    DenseIntElementsAttr edge_padding_low,
-    DenseIntElementsAttr edge_padding_high, DenseIntElementsAttr padding_attr,
-    ConvDimensionNumbers dimension_numbers, Builder* builder) {
-  SmallVector<int64_t, 4> padding_low, padding_high;
-  for (const auto& dim : dimension_numbers.input_spatial_dimensions()) {
-    unsigned i = dim.getZExtValue();
-    padding_low.push_back(edge_padding_low.getValue<int64_t>(i));
-    padding_high.push_back(edge_padding_high.getValue<int64_t>(i));
-  }
-
-  int rank = padding_low.size();
-  SmallVector<int64_t, 8> padding;
-  for (unsigned i = 0, e = rank; i < e; ++i) {
-    padding.push_back(GetPaddingValue(padding_attr, {i, 0}) + padding_low[i]);
-    padding.push_back(GetPaddingValue(padding_attr, {i, 1}) + padding_high[i]);
-  }
-  // padding_attr.getType() doesn't work because it is an optional attribute,
-  // which can be a nullptr.
-  auto type = RankedTensorType::get({rank, 2}, builder->getIntegerType(64));
-  return DenseIntElementsAttr::get(type, padding);
-}
-
 #include "tensorflow/compiler/mlir/xla/transforms/generated_canonicalize.inc"
 } // namespace
 
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// ConvOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
void ConvOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
|
||||||
MLIRContext* context) {
|
|
||||||
results.insert<FoldPadIntoConv>(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace xla_hlo
|
} // namespace xla_hlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
@@ -929,8 +929,6 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
   );
 
   let results = (outs HLO_Tensor);
-
-  let hasCanonicalizer = 1;
 }
 
 def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CopyOp {
@@ -132,6 +132,22 @@ StatusOr<XlaOp> MlirHloBuilder::FftInternal(
   return MakeXlaOp(op);
 }
 
+StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
+    const string& call_target_name, absl::Span<const XlaOp> operands,
+    const Shape& shape, const string& opaque,
+    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
+  if (operand_shapes_with_layout.has_value())
+    return Unimplemented(
+        "CustomCall doesn't support operands shapes with layout");
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  auto op = builder_.create<mlir::xla_hlo::CustomCallOp>(
+      loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
+      /*has_side_effect=*/builder_.getBoolAttr(false),
+      builder_.getStringAttr(opaque));
+  return MakeXlaOp(op);
+}
+
 StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
     const Shape& shape, absl::Span<const XlaOp> all_operands,
     const XlaComputation& computation,
@@ -124,6 +124,12 @@ class MlirHloBuilder : public XlaBuilder {
                                  FftType fft_type,
                                  absl::Span<const int64> fft_length) override;
 
+  StatusOr<XlaOp> CustomCallInternal(const string& call_target_name,
+                                     absl::Span<const XlaOp> operands,
+                                     const Shape& shape, const string& opaque,
+                                     absl::optional<absl::Span<const Shape>>
+                                         operand_shapes_with_layout) override;
+
   StatusOr<XlaOp> ReduceInternal(
       const Shape& shape, absl::Span<const XlaOp> all_operands,
       const XlaComputation& computation,
@@ -415,71 +415,6 @@ func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> {
   return %0 : tensor<1x4xf32>
 }
 
-// CHECK-LABEL: func @fold_pad_into_conv_f32
-func @fold_pad_into_conv_f32(%arg0 : tensor<1x32x32x3xf32>,
-                             %arg1 : tensor<7x7x3x64xf32>)
-    -> tensor<1x16x16x64xf32> {
-  // CHECK-NOT: xla_hlo.pad
-  // CHECK: xla_hlo.convolution
-  // CHECK-SAME: padding = dense<3> : tensor<2x2xi64>
-  %0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
-  %1 = "xla_hlo.pad"(%arg0, %0) {
-    edge_padding_high = dense<[0, 3, 3, 0]> : tensor<4xi64>,
-    edge_padding_low = dense<[0, 3, 3, 0]> : tensor<4xi64>,
-    interior_padding = dense<0> : tensor<4xi64>
-  } : (tensor<1x32x32x3xf32>, tensor<f32>) -> tensor<1x38x38x3xf32>
-  %2 = "xla_hlo.convolution"(%1, %arg1) {
-    batch_group_count = 1 : i64,
-    dimension_numbers = {
-      input_batch_dimension = 0 : i64,
-      input_feature_dimension = 3 : i64,
-      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
-      kernel_input_feature_dimension = 2 : i64,
-      kernel_output_feature_dimension = 3 : i64,
-      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
-      output_batch_dimension = 0 : i64,
-      output_feature_dimension = 3 : i64,
-      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
-    },
-    feature_group_count = 1 : i64,
-    padding = dense<0> : tensor<2x2xi64>,
-    window_strides = dense<2> : tensor<2xi64>
-  } : (tensor<1x38x38x3xf32>, tensor<7x7x3x64xf32>) -> tensor<1x16x16x64xf32>
-  return %2 : tensor<1x16x16x64xf32>
-}
-
-// CHECK-LABEL: func @fold_pad_into_conv_i32
-func @fold_pad_into_conv_i32(%arg0 : tensor<1x32x32x3xi32>,
-                             %arg1 : tensor<7x7x3x64xi32>)
-    -> tensor<1x16x16x64xi32> {
-  // CHECK-NOT: xla_hlo.pad
-  // CHECK: xla_hlo.convolution
-  // CHECK-SAME: padding = dense<3> : tensor<2x2xi64>
-  %0 = xla_hlo.constant dense<0> : tensor<i32>
-  %1 = "xla_hlo.pad"(%arg0, %0) {
-    edge_padding_high = dense<[0, 3, 3, 0]> : tensor<4xi64>,
-    edge_padding_low = dense<[0, 3, 3, 0]> : tensor<4xi64>,
-    interior_padding = dense<0> : tensor<4xi64>
-  } : (tensor<1x32x32x3xi32>, tensor<i32>) -> tensor<1x38x38x3xi32>
-  %2 = "xla_hlo.convolution"(%1, %arg1) {
-    batch_group_count = 1 : i64,
-    dimension_numbers = {
-      input_batch_dimension = 0 : i64,
-      input_feature_dimension = 3 : i64,
-      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
-      kernel_input_feature_dimension = 2 : i64,
-      kernel_output_feature_dimension = 3 : i64,
-      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
-      output_batch_dimension = 0 : i64,
-      output_feature_dimension = 3 : i64,
-      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
-    },
-    feature_group_count = 1 : i64,
-    window_strides = dense<2> : tensor<2xi64>
-  } : (tensor<1x38x38x3xi32>, tensor<7x7x3x64xi32>) -> tensor<1x16x16x64xi32>
-  return %2 : tensor<1x16x16x64xi32>
-}
-
 // CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic
 func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> {
   // CHECK: xla_hlo.reshape
@@ -35,7 +35,7 @@ func @not_whitelisted_op(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor
 // CHECK-LABEL: unranked_operand
 func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> {
   // CHECK: tf.Abs
-  // expected-remark@+1 {{lowering requires static shaped operands}}
+  // expected-remark@+1 {{lowering requires static shaped tensor operands}}
   %0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
 
   return %0 : tensor<*xf32>
@@ -44,12 +44,20 @@ func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> {
 // CHECK-LABEL: dynamic_operand
 func @dynamic_operand(%arg0: tensor<?xf32>) -> tensor<?xf32> {
   // CHECK: tf.Abs
-  // expected-remark@+1 {{lowering requires static shaped operands}}
+  // expected-remark@+1 {{lowering requires static shaped tensor operands}}
   %0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
 
   return %0 : tensor<?xf32>
 }
 
+// CHECK-LABEL: tuple_type
+func @tuple_type(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
+  // Verifies that the pass can handle operands of non-tensor type like tuple
+  // from non TensorFlow ops.
+  %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
 // CHECK-LABEL: unsupported_dtype
 func @unsupported_dtype(%arg0: tensor<2x!tf.variant>) -> tensor<2x!tf.variant> {
   // CHECK: tf.AddN
@@ -28,54 +28,3 @@ def UnaryEinsumToEinsum : Pat<
   (HLO_UnaryEinsumOp $operand, $equation),
   (HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)),
                 $operand, (UnaryToBinaryEinsumEq $equation))>;
-
-//===----------------------------------------------------------------------===//
-// Conv op patterns.
-//===----------------------------------------------------------------------===//
-
-def IsZero : Attr<CPred<
-  "($_self.isa<DenseFPElementsAttr>() &&"
-  "$_self.cast<DenseFPElementsAttr>().isSplat() &&"
-  "$_self.cast<DenseFPElementsAttr>().getSplatValue<FloatAttr>()"
-  ".getValue().isZero()) ||"
-  "($_self.isa<DenseIntElementsAttr>() &&"
-  "$_self.cast<DenseIntElementsAttr>().isSplat() &&"
-  "$_self.cast<DenseIntElementsAttr>().getSplatValue<IntegerAttr>()"
-  ".getInt() == 0)">>;
-
-def IsOnlyPaddingSpatialDims
-  : Constraint<CPred<"IsOnlyPaddingSpatialDims($0, $1, $2, $3)">>;
-
-def BuildConvPaddingAttrs : NativeCodeCall<
-  "BuildConvPaddingAttrs($0, $1, $2, $3, &$_builder)">;
-
-def FoldPadIntoConv : Pat<
-  (HLO_ConvOp
-    (HLO_PadOp $lhs,
-      (HLO_ConstOp IsZero:$padding_value),
-      $edge_padding_low,
-      $edge_padding_high,
-      IsZero:$interior_padding),
-    $rhs,
-    $window_strides,
-    $padding,
-    $lhs_dilation,
-    $rhs_dilation,
-    $dimension_numbers,
-    $feature_group_count,
-    $batch_group_count,
-    $precision_config),
-  (HLO_ConvOp
-    $lhs,
-    $rhs,
-    $window_strides,
-    (BuildConvPaddingAttrs $edge_padding_low, $edge_padding_high, $padding,
-      $dimension_numbers),
-    $lhs_dilation,
-    $rhs_dilation,
-    $dimension_numbers,
-    $feature_group_count,
-    $batch_group_count,
-    $precision_config),
-  [(IsOnlyPaddingSpatialDims $lhs, $dimension_numbers, $edge_padding_low,
-    $edge_padding_high)]>;
@@ -389,10 +389,13 @@ struct HloLegalizeToLhlo
     target.addLegalOp<ModuleTerminatorOp>();
     target.addLegalOp<TensorFromElementsOp>();
     target.addIllegalDialect<xla_hlo::XlaHloDialect>();
+
+    BufferAssignmentTypeConverter converter;
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
       auto inputs = op.getType().getInputs();
-      return std::all_of(inputs.begin(), inputs.end(),
-                         [](Type input) { return input.isa<MemRefType>(); });
+      return llvm::all_of(inputs,
+                          [](Type input) { return input.isa<MemRefType>(); }) &&
+             converter.isLegal(&op.getBody());
     });
     target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
       return std::all_of(returnOp.operand_type_begin(),
@@ -401,8 +404,7 @@ struct HloLegalizeToLhlo
     });
 
     auto module = getOperation();
-    BufferAssignmentTypeConverter converter;
-    module.walk([&](FuncOp func) {
+    module.walk([&](FuncOp func) -> WalkResult {
       BufferAssignmentPlacer bufferAssignment(func);
       OwningRewritePatternList patterns;
       populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment,
@@ -418,8 +420,7 @@ struct HloLegalizeToLhlo
           /*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment,
                                                 &converter, &patterns);
       }
-      return WalkResult(
-          applyPartialConversion(func, target, patterns, &converter));
+      return applyPartialConversion(func, target, patterns);
     });
   }
 
@@ -463,6 +464,7 @@ void populateHLOToLHLOConversionPattern(
                     HloToLhloOpConverter<xla_hlo::RealOp>,
                     HloToLhloOpConverter<xla_hlo::RemOp>,
                     HloToLhloOpConverter<xla_hlo::RsqrtOp>,
+                    HloToLhloOpConverter<xla_hlo::ReshapeOp>,
                     HloToLhloOpConverter<xla_hlo::SelectOp>,
                     HloToLhloOpConverter<xla_hlo::SignOp>,
                     HloToLhloOpConverter<xla_hlo::SqrtOp>,
@@ -5238,8 +5238,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
   // Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp.
   target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ::mlir::ReturnOp>();
   DenseSet<Operation *> nonlegalized_ops;
-  LogicalResult result = applyPartialConversion(
-      op, target, patterns, /*converter=*/nullptr, &nonlegalized_ops);
+  LogicalResult result =
+      applyPartialConversion(op, target, patterns, &nonlegalized_ops);
   // In order to enforce that the conversion result is fully converted,
   // fail if there are any nonlegalized ops in the set.
   if (failed(result) || !nonlegalized_ops.empty()) {
@@ -88,6 +88,9 @@ static bool IsOpWhitelisted(Operation* op) {
     TypeID::get<TF::AddNOp>(),
     TypeID::get<TF::AddV2Op>(),
     TypeID::get<TF::AngleOp>(),
+    TypeID::get<TF::AdjustContrastv2Op>(),
+    TypeID::get<TF::AdjustHueOp>(),
+    TypeID::get<TF::AdjustSaturationOp>(),
     TypeID::get<TF::ApproximateEqualOp>(),
     TypeID::get<TF::ArgMaxOp>(),
     TypeID::get<TF::ArgMinOp>(),
@@ -127,6 +130,7 @@ static bool IsOpWhitelisted(Operation* op) {
     TypeID::get<TF::GatherNdOp>(),
     TypeID::get<TF::GreaterEqualOp>(),
     TypeID::get<TF::GreaterOp>(),
+    TypeID::get<TF::HSVToRGBOp>(),
     TypeID::get<TF::IFFT2DOp>(),
     TypeID::get<TF::IFFT3DOp>(),
     TypeID::get<TF::IFFTOp>(),
@@ -157,10 +161,14 @@ static bool IsOpWhitelisted(Operation* op) {
     TypeID::get<TF::PowOp>(),
     TypeID::get<TF::RFFT2DOp>(),
     TypeID::get<TF::RFFT3DOp>(),
+    TypeID::get<TF::RGBToHSVOp>(),
     TypeID::get<TF::RealDivOp>(),
     TypeID::get<TF::ReciprocalOp>(),
     TypeID::get<TF::ReciprocalGradOp>(),
     TypeID::get<TF::Relu6GradOp>(),
+    TypeID::get<TF::ResizeBilinearOp>(),
+    TypeID::get<TF::ResizeBilinearGradOp>(),
+    TypeID::get<TF::ResizeNearestNeighborOp>(),
     TypeID::get<TF::ReverseSequenceOp>(),
     TypeID::get<TF::RightShiftOp>(),
     TypeID::get<TF::RintOp>(),
@@ -337,9 +345,9 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
 
   // Only static shaped operands are supported in XLA builders for now.
   for (Type ty : op->getOperandTypes()) {
-    auto ranked_ty = ty.cast<ShapedType>();
-    if (!ranked_ty.hasStaticShape()) {
-      op->emitRemark() << "lowering requires static shaped operands";
+    auto ranked_ty = ty.dyn_cast<ShapedType>();
+    if (!ranked_ty || !ranked_ty.hasStaticShape()) {
+      op->emitRemark() << "lowering requires static shaped tensor operands";
       return success();
     }
   }
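The switch from cast<ShapedType> to dyn_cast<ShapedType> above matters because, with the new tuple_type test, operands of non-tensor type now reach this loop: cast would assert on them, while dyn_cast returns a null value and the op is simply skipped with a remark. The same contrast in miniature, using an invented two-type hierarchy with LLVM-style RTTI (not MLIR types):

#include <cassert>

#include "llvm/Support/Casting.h"

// Tiny stand-in hierarchy: cast<> asserts when the dynamic type does not
// match, dyn_cast<> returns null so the caller can skip the value instead
// of crashing.
struct Type {
  enum Kind { kShaped, kTuple } kind;
  explicit Type(Kind k) : kind(k) {}
};
struct ShapedType : Type {
  ShapedType() : Type(kShaped) {}
  static bool classof(const Type* t) { return t->kind == kShaped; }
};
struct TupleType : Type {
  TupleType() : Type(kTuple) {}
  static bool classof(const Type* t) { return t->kind == kTuple; }
};

int main() {
  TupleType tuple;
  Type* ty = &tuple;
  // llvm::cast<ShapedType>(ty) would assert here; dyn_cast reports "not shaped".
  assert(llvm::dyn_cast<ShapedType>(ty) == nullptr);
  return 0;
}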
@@ -177,7 +177,7 @@ struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
     target.addIllegalOp<ReduceOp>();
     auto func = getFunction();
     patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
-    if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
+    if (failed(applyPartialConversion(func, target, patterns))) {
       signalPassFailure();
     }
   }
@@ -43,7 +43,7 @@ class TestLhloToLLVMPass
     target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
     target.addIllegalDialect<XlaLhloDialect>();
 
-    if (failed(applyFullConversion(m, target, patterns, &converter))) {
+    if (failed(applyFullConversion(m, target, patterns))) {
       signalPassFailure();
     }
   }
@@ -711,7 +711,7 @@ struct LhloLegalizeToParallelLoops
     target.addIllegalOp<xla_lhlo::ReduceOp, xla_lhlo::ReduceWindowOp,
                         xla_lhlo::SelectAndScatterOp>();
 
-    if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
+    if (failed(applyPartialConversion(func, target, patterns))) {
       signalPassFailure();
     }
   }
@@ -61,6 +61,7 @@ MAP_HLO_TO_LHLO(MulOp);
 MAP_HLO_TO_LHLO(NegOp);
 MAP_HLO_TO_LHLO(RealOp);
 MAP_HLO_TO_LHLO(ReduceOp);
+MAP_HLO_TO_LHLO(ReshapeOp);
 MAP_HLO_TO_LHLO(RemOp);
 MAP_HLO_TO_LHLO(RsqrtOp);
 MAP_HLO_TO_LHLO(SelectOp);
@@ -867,7 +867,7 @@ struct LhloLegalizeToLinalg
 
     auto func = getFunction();
     populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
-    if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
+    if (failed(applyPartialConversion(func, target, patterns))) {
       signalPassFailure();
     }
   }
@@ -882,7 +882,7 @@ struct HloLegalizeToLinalg
 
     auto func = getFunction();
     xla_hlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
-    if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
+    if (failed(applyPartialConversion(func, target, patterns))) {
       signalPassFailure();
     }
   }
@@ -1,6 +1,6 @@
 // Test DeviceIndex selector.
 
-// RUN: tf-opt --tfl-device-index-selector %s | FileCheck %s
+// RUN: tf-opt --tf-device-index-selector %s | FileCheck %s
 
 // CHECK-LABEL: func @select
 func @select(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<i32>, tensor<f32>) {
@@ -770,6 +770,7 @@ tf_xla_py_test(
     size = "small",
     timeout = "long",
     srcs = ["image_ops_test.py"],
+    enable_mlir_bridge = True,
    python_version = "PY3",
    shard_count = 10,
    tags = [
@@ -1564,16 +1564,12 @@ XlaOp XlaBuilder::CustomCall(
     const Shape& shape, const string& opaque,
     absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
     if (absl::StartsWith(call_target_name, "$")) {
       return InvalidArgument(
           "Invalid custom_call_target \"%s\": Call targets that start with '$' "
           "are reserved for internal use.",
          call_target_name);
     }
-    *instr.mutable_shape() = shape.ToProto();
-    instr.set_custom_call_target(call_target_name);
-    instr.set_backend_config(opaque);
     if (operand_shapes_with_layout.has_value()) {
       if (!LayoutUtil::HasLayout(shape)) {
         return InvalidArgument(
@@ -1586,7 +1582,6 @@ XlaOp XlaBuilder::CustomCall(
             "with constrained layout; given %d shapes, expected %d",
             operand_shapes_with_layout->size(), operands.size());
       }
-      instr.set_constrain_layout(true);
       int64 operand_num = 0;
       for (const Shape& operand_shape : *operand_shapes_with_layout) {
         if (!LayoutUtil::HasLayout(operand_shape)) {
@@ -1595,14 +1590,31 @@ XlaOp XlaBuilder::CustomCall(
               "constrained layout.",
               operand_num);
         }
-        *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
         ++operand_num;
       }
     }
-    return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
+    return CustomCallInternal(call_target_name, operands, shape, opaque,
+                              operand_shapes_with_layout);
   });
 }
 
+StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
+    const string& call_target_name, absl::Span<const XlaOp> operands,
+    const Shape& shape, const string& opaque,
+    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
+  HloInstructionProto instr;
+  *instr.mutable_shape() = shape.ToProto();
+  instr.set_custom_call_target(call_target_name);
+  instr.set_backend_config(opaque);
+  if (operand_shapes_with_layout.has_value()) {
+    instr.set_constrain_layout(true);
+    for (const Shape& operand_shape : *operand_shapes_with_layout) {
+      *instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
+    }
+  }
+  return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
+}
+
 XlaOp XlaBuilder::CustomCall(
     const string& call_target_name, absl::Span<const XlaOp> operands,
     const XlaComputation& computation, const Shape& shape, const string& opaque,
@@ -527,6 +527,14 @@ class XlaBuilder {
       const Shape& shape_with_layout, const string& opaque,
       absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
 
+  // Internal version of CustomCall without computation that doesn't do op
+  // specific error handling and expects arguments to be legal. CustomCall
+  // method above calls this method after error handling.
+  virtual StatusOr<XlaOp> CustomCallInternal(
+      const string& call_target_name, absl::Span<const XlaOp> operands,
+      const Shape& shape_with_layout, const string& opaque,
+      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
+
   XlaOp CustomCall(
       const string& call_target_name, absl::Span<const XlaOp> operands,
       const XlaComputation& computation, const Shape& shape_with_layout,
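Taken together with the MlirHloBuilder changes earlier in this diff, the new virtual CustomCallInternal splits CustomCall into shared argument validation and an overridable construction step. A skeletal illustration of that pattern; every name here other than CustomCall/CustomCallInternal is invented for the example:

#include <iostream>
#include <string>

// Validation stays in the public entry point; construction is a virtual hook
// that a different backend (here: an MLIR-emitting builder) can replace.
class Builder {
 public:
  virtual ~Builder() = default;
  std::string CustomCall(const std::string& target) {
    if (target.empty() || target[0] == '$') return "error: reserved target";
    return CustomCallInternal(target);  // shared error handling done above
  }

 protected:
  virtual std::string CustomCallInternal(const std::string& target) {
    return "proto instruction for " + target;  // default: HLO-proto path
  }
};

class MlirBuilder : public Builder {
 protected:
  std::string CustomCallInternal(const std::string& target) override {
    return "mlir op for " + target;            // overridden construction step
  }
};

int main() {
  MlirBuilder b;
  std::cout << b.CustomCall("my_kernel") << "\n";  // "mlir op for my_kernel"
  std::cout << b.CustomCall("$internal") << "\n";  // validation still shared
  return 0;
}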
@@ -141,7 +141,9 @@ cc_library(
         "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
         "//tensorflow/core:allocator",
         "//tensorflow/core:lib",
+        "//tensorflow/core/profiler/lib:connected_traceme",
        "//tensorflow/core/profiler/lib:traceme",
+        "//tensorflow/core/profiler/lib:traceme_encode",
        "//tensorflow/stream_executor:event",
        "//tensorflow/stream_executor:stream",
        "//tensorflow/stream_executor/host:host_platform_id",
@ -127,11 +127,15 @@ std::unique_ptr<se::Stream> LocalDeviceState::BorrowStreamFromPool() {
} else {
std::unique_ptr<se::Stream> stream = std::move(usage_stream_pool_.top());
usage_stream_pool_.pop();
+stream->RefreshStatus().IgnoreError();  // Can return error::Unimplemented
+QCHECK(stream->ok());
return stream;
}
}

void LocalDeviceState::ReturnStreamToPool(std::unique_ptr<se::Stream> stream) {
+stream->RefreshStatus().IgnoreError();  // Can return error::Unimplemented
+QCHECK(stream->ok());
absl::MutexLock lock(&mu_);
usage_stream_pool_.push(std::move(stream));
}
@ -98,7 +98,9 @@ limitations under the License.
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/core/profiler/lib/traceme_encode.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/event.h"
|
|||||||
// memory that has already been allocated, and a possible Event
|
// memory that has already been allocated, and a possible Event
|
||||||
// allocation.
|
// allocation.
|
||||||
|
|
||||||
|
se::Stream* h2d_stream = local_device->host_to_device_stream();
|
||||||
ShapedBuffer buffer = device_buffer->AsShapedBuffer(
|
ShapedBuffer buffer = device_buffer->AsShapedBuffer(
|
||||||
compact_shape, on_device_shape, client->client()->platform());
|
compact_shape, on_device_shape, client->client()->platform());
|
||||||
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
|
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
|
||||||
local_device->host_to_device_stream(), literal, buffer));
|
h2d_stream, literal, buffer));
|
||||||
|
|
||||||
std::shared_ptr<BufferSequencingEvent> event =
|
std::shared_ptr<BufferSequencingEvent> event =
|
||||||
device_buffer->definition_events()[0];
|
device_buffer->definition_events()[0];
|
||||||
TF_CHECK_OK(AddDestinationBufferSynchronization(
|
TF_CHECK_OK(AddDestinationBufferSynchronization(
|
||||||
local_device, std::move(device_buffer), event,
|
local_device, std::move(device_buffer), event, h2d_stream));
|
||||||
local_device->host_to_device_stream()));
|
|
||||||
|
// This can sometimes catch the case where the literal memory has been
|
||||||
|
// freed before the H2D transfer was issued.
|
||||||
|
h2d_stream->RefreshStatus()
|
||||||
|
.IgnoreError(); // Can return error::Unimplemented
|
||||||
|
QCHECK(h2d_stream->ok());
|
||||||
};
|
};
|
||||||
client->h2d_transfer_pool()->Schedule(transfer_h2d);
|
client->h2d_transfer_pool()->Schedule(transfer_h2d);
|
||||||
return py_buffer;
|
return py_buffer;
|
||||||
@ -1069,13 +1077,17 @@ Status PjRtBuffer::CopyToHostAsync() {
return Status::OK();
}

-StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral() {
+StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
+const bool discard_cached_copy) {
tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral");
TF_RETURN_IF_ERROR(CopyToHostAsync());
std::shared_ptr<HostValue> host_value;
{
absl::MutexLock lock(&mu_);
host_value = host_value_;
+if (discard_cached_copy) {
+host_value_ = nullptr;
+}
}
if (host_value == nullptr) {
return InvalidArgument("ToLiteral called on invalid buffer");
@ -1429,10 +1441,9 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
Device* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers) const {
int device_ordinal = device->local_device_state()->device_ordinal();
-tensorflow::profiler::TraceMe traceme([&] {
-return absl::StrCat("LocalExecutable::Execute#run_id=", run_id.ToInt(),
-"#");
-});
+tensorflow::profiler::TraceMeConsumer activity(
+"LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
+run_id.ToInt());
VLOG(3) << "Replica " << replica << ", partition " << partition
<< " mapped to device ordinal for execution: " << device_ordinal;

@ -1721,10 +1732,9 @@ PjRtExecutable::ExecuteOnLocalDevices(
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
const ExecuteOptions& options) const {
RunId run_id;
-tensorflow::profiler::TraceMe traceme([&] {
-return absl::StrCat(
-"LocalExecutable::ExecuteOnLocalDevices#run_id=", run_id.ToInt(), "#");
-});
+tensorflow::profiler::TraceMeProducer activity(
+"LocalExecutable::ExecuteOnLocalDevices",
+tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());

const int num_local_devices = local_devices_.size();

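For context on the profiler change above: the TraceMe lambdas are replaced by a TraceMeProducer in ExecuteOnLocalDevices and a TraceMeConsumer in EnqueueExecution, both keyed by the same RunId, so the profiler can stitch the enqueue step and the per-device execution into one connected trace. A rough sketch of the same pattern in isolation (function names are illustrative; the API usage mirrors the lines above):

#include "tensorflow/core/profiler/lib/connected_traceme.h"

void EnqueueWork(int64 run_id) {
  tensorflow::profiler::TraceMeProducer producer(
      "EnqueueWork", tensorflow::profiler::ContextType::kPjRt, run_id);
  // ... hand the run off to a worker thread ...
}

void RunWork(int64 run_id) {
  tensorflow::profiler::TraceMeConsumer consumer(
      "RunWork", tensorflow::profiler::ContextType::kPjRt, run_id);
  // ... execute; the profiler links this span to the producer that used the
  // same run_id ...
}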
@ -478,8 +478,12 @@ class PjRtBuffer {

// Returns the buffer's value as an XLA Literal. If the value has previously
// been prefetched to the host, then returns the prefetched version, otherwise
-// copies the buffer to the host. Blocks until the value is ready.
-StatusOr<std::shared_ptr<Literal>> ToLiteral();
+// copies the buffer to the host. Blocks until the value is ready. If
+// `discard_cached_copy` is true then buffer will no longer keep hold of a
+// cached copy of the literal (i.e. The reference to the host value will be
+// removed.)
+StatusOr<std::shared_ptr<Literal>> ToLiteral(
+bool discard_cached_copy = false);

// Initiates a copy of the buffer to the host. Does not block waiting for
// the transfer to complete. The value can be retrieved by a later call to
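A short usage sketch for the new parameter (buffer stands for any PjRtBuffer*; the snippet is illustrative and not part of the change):

// Default: the host copy stays cached on the buffer, as before.
TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, buffer->ToLiteral());

// One-shot read: transfer the value, then drop the cached host copy.
TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal_once,
                    buffer->ToLiteral(/*discard_cached_copy=*/true));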
@ -106,7 +106,6 @@ class BranchVisitor {
boundaries_.emplace_back(operand, i, inst);
continue;
}
-
worklist_.push_back(operand);
visited_.insert(operand);
}
@ -197,6 +196,7 @@ bool WorthHoisting(HloInstruction* instruction) {
case HloOpcode::kMultiply:
case HloOpcode::kDivide:
case HloOpcode::kTuple:
+case HloOpcode::kSqrt:
case HloOpcode::kGetTupleElement:
return true;
default:
@ -206,10 +206,11 @@ bool WorthHoisting(HloInstruction* instruction) {

// Compare if the instructions to be visited at each branches are identical.
bool InstructionWithinBranchIdentical(
-const std::vector<HloInstruction*>& instructions, bool is_layout_senstive) {
+const std::vector<HloInstruction*>& instructions,
+bool is_layout_sensitive) {
// Identical includes the shape of each operands are equal.
auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) {
-bool eq_operands = is_layout_senstive
+bool eq_operands = is_layout_sensitive
? ShapeUtil::Equal(a->shape(), b->shape())
: ShapeUtil::Compatible(a->shape(), b->shape());
return eq_operands;
@ -233,7 +234,7 @@ bool InstructionWithinBranchIdentical(
auto old_channel_id = instruction->channel_id();
instruction->set_channel_id(instructions[0]->channel_id());
bool eq_instructions = instructions[0]->Identical(
-*instruction, eq_operand, eq_computations, is_layout_senstive);
+*instruction, eq_operand, eq_computations, is_layout_sensitive);
instruction->set_channel_id(old_channel_id);
return eq_instructions;
});
@ -243,7 +244,7 @@ bool InstructionWithinBranchIdentical(
[&](HloInstruction* instruction) {
return instructions[0]->Identical(
*instruction, eq_operand, eq_computations,
-is_layout_senstive);
+is_layout_sensitive);
});
}

@ -354,12 +355,228 @@ Status RemoveInstructionFromComputation(
return Status::OK();
}

+// Identify converts to be hoisted/rematerialized out of the branch
+// computations.
+absl::flat_hash_set<int64> FindSpecialConverts(HloInstruction* old_root,
+int branch_count,
+HloInstruction* conditional,
+bool is_layout_sensitive) {
+absl::flat_hash_set<int64> kspecial_convert;
+for (int64 operand_num = 0; operand_num < old_root->operand_count();
+++operand_num) {
+if (old_root->operand(operand_num)->opcode() != HloOpcode::kConvert) {
+continue;
+}
+bool replica = true;
+HloInstruction* kspecial_convert_candidate =
+old_root->mutable_operand(operand_num);
+// Check whether an identical candidate appears in other branches
+for (int others = 1; others < branch_count; ++others) {
+HloInstruction* others_root =
+conditional->branch_computation(others)->root_instruction();
+bool eq_shape =
+is_layout_sensitive
+? ShapeUtil::Equal(others_root->operand(operand_num)->shape(),
+kspecial_convert_candidate->shape())
+: ShapeUtil::Compatible(
+others_root->operand(operand_num)->shape(),
+kspecial_convert_candidate->shape());
+if ((others_root->operand(operand_num)->opcode() ==
+HloOpcode::kConvert) &&
+eq_shape) {
+// Nothing to be done.
+} else {
+replica = false;
+break;
+}
+}
+if (replica) {
+kspecial_convert.insert(operand_num);
+}
+}
+return kspecial_convert;
+}

+// Restructuring the conditional instruction as follows:
+// i.e., %result = conditional() becomes
+// x = conditional()
+// y.{0..n} = gte(x, {0..n})
+// z = tuple(y.0, y.1, ...y.n)
+// Doing so ensures that we can accommodate the possible shape-change of the
+// conditional when the instructions are hoisted.
+Status RestructureConditionalInstruction(HloComputation* computation,
+HloInstruction* conditional) {
+HloInstruction* old_root = computation->root_instruction();
+std::vector<HloInstruction*> new_operands;
+int cur_index = 0;
+for (; cur_index < ShapeUtil::TupleElementCount(conditional->shape());
+++cur_index) {
+new_operands.push_back(
+computation->AddInstruction(HloInstruction::CreateGetTupleElement(
+ShapeUtil::GetTupleElementShape(conditional->shape(), cur_index),
+conditional, cur_index)));
+}
+HloInstruction* new_tuple =
+computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
+if (old_root == conditional) {
+computation->set_root_instruction(new_tuple);
+} else {
+std::vector<HloInstruction*> new_tuple_users;
+for (auto conditional_user : conditional->users()) {
+auto is_new_gte = absl::c_find_if(
+new_operands,
+[&](HloInstruction* instr) { return instr == conditional_user; });
+if (is_new_gte == new_operands.end()) {
+new_tuple_users.push_back(conditional_user);
+}
+}
+for (auto new_tuple_user : new_tuple_users) {
+TF_RETURN_IF_ERROR(
+conditional->ReplaceUseWith(new_tuple_user, new_tuple));
+}
+}
+VLOG(2) << "computation after root restructure:\n" << computation->ToString();
+return Status::OK();
+}

+StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
+bool is_layout_sensitive) {
+int branch_count = conditional->branch_count();
+if (branch_count <= 0) {
+return false;
+}

+HloInstruction* old_root =
+conditional->branch_computation(0)->root_instruction();
+if (old_root->opcode() != HloOpcode::kTuple) {
+return false;
+} else {
+VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
+// Identify the gte using `index'.
+auto find_gte = [](const HloInstruction* conditional_result,
+int64 index) -> HloInstruction* {
+for (HloInstruction* instr : conditional_result->users()) {
+if (instr->opcode() != HloOpcode::kGetTupleElement) {
+return nullptr;
+}
+if (instr->tuple_index() == index) {
+return instr;
+}
+}
+return nullptr;
+};

+// Captures tuple indices refering to converts to be rematerialized/hoisted.
+absl::flat_hash_set<int64> kspecial_convert = FindSpecialConverts(
+old_root, branch_count, conditional, is_layout_sensitive);

+// Exit if we cannot find any converts to be hoisted.
+if (kspecial_convert.empty()) {
+return false;
+}

+TF_RETURN_IF_ERROR(
+RestructureConditionalInstruction(conditional->parent(), conditional));

+for (int branch = 0; branch < branch_count; branch++) {
+old_root = conditional->branch_computation(branch)->root_instruction();
+absl::flat_hash_map<HloInstruction*, int64> map_inst_to_tuple_index;
+std::vector<HloInstruction*> new_operands(old_root->operand_count());
+std::unordered_set<HloInstruction*> to_hoist_set;

+for (int64 operand_num = 0; operand_num < old_root->operand_count();
+++operand_num) {
+map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
+operand_num;
+}
+for (int64 operand_num = 0; operand_num < old_root->operand_count();
+++operand_num) {
+HloInstruction* hoist = old_root->mutable_operand(operand_num);
+if (!kspecial_convert.contains(operand_num)) {
+new_operands[operand_num] = old_root->mutable_operand(operand_num);
+continue;
+}

+to_hoist_set.insert(hoist);
+int64 new_tuple_count = old_root->operand_count();

+// Replace the hoisted instr in the tuple with the operand/operands.
+// We will replace at least one of the operands of the hoist at the
+// tuple place; the rest will be added at the end.
+bool inplace = true;
+CHECK(!hoist->operands().empty());
+for (HloInstruction* prod : hoist->operands()) {
+if (inplace) {
+map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
+new_operands[map_inst_to_tuple_index[hoist]] = prod;
+inplace = false;
+} else {
+map_inst_to_tuple_index[prod] = new_tuple_count++;
+new_operands.push_back(prod);
+}
+}
+}

+// Create the new root instruction.
+HloComputation* cur_branch = conditional->branch_computation(branch);
+HloInstruction* new_branch_root =
+cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
+// The shape can vary since the operands to convert are now
+// being returned through the branches' root.
+cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
+TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));

+// Only one of the branches needs to change the conditional->parent().
+if (branch != 0) {
+continue;
+}
+HloComputation* conditional_parent = conditional->parent();
+HloInstruction* newconditional =
+conditional_parent->AddInstruction(HloInstruction::CreateConditional(
+cur_branch->root_instruction()->shape(),
+conditional->mutable_operand(0),
+absl::MakeSpan(conditional->branch_computations()),
+absl::MakeSpan(conditional->operands()).subspan(1)));
+// Ensure that all the users of conditional refer to the new one.
+TF_RETURN_IF_ERROR(
+conditional->ReplaceAllUsesWithDifferentShape(newconditional));
+TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
+conditional = newconditional;
+// Add the hoisted instructions in the parent.
+for (HloInstruction* hoist : to_hoist_set) {
+VLOG(2) << "Hoisting instruction:" << hoist->ToString();
+int64 hoist_index = map_inst_to_tuple_index[hoist];
+// Find out the gte that captured the hoisted instr result.
+HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
+CHECK(gte_hoist != nullptr);
+std::vector<HloInstruction*> new_operands;
+for (HloInstruction* op : hoist->operands()) {
+HloInstruction* gte = conditional_parent->AddInstruction(
+HloInstruction::CreateGetTupleElement(
+op->shape(), conditional, map_inst_to_tuple_index[op]));
+new_operands.push_back(gte);
+}
+HloInstruction* hoisted = conditional_parent->AddInstruction(
+hoist->CloneWithNewOperands(hoist->shape(), new_operands));
+VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
+TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
+TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
+}
+// No need to explicitly delete a hoisted instruction since if its dead
+// then the subsequent DCE will remove it.
+}
+}
+VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
+return true;
+}

// Hoist identical ops out of the conditional. The definition of identical
// are the shape of the operands are identical and their properties are
// identical. Will start from the root instruction of each branch and get
// the identical ops to hoist.
StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
bool is_layout_sensitive) {
+VLOG(1) << " visiting conditional:" << conditional->ToString();
int branch_count = conditional->branch_count();
if (branch_count <= 0) {
return false;
@ -399,7 +616,7 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
}
}

-if (visitors[0].HoistInstructionSize() <= 1) {
+if (visitors[0].HoistInstructionSize() < 1) {
return false;
}

@ -442,7 +659,6 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
RemoveInstructionFromComputation(visitors[i].instructions_to_hoist(),
conditional->branch_computation(i)));
}
-
return true;
}

@ -451,9 +667,7 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
bool changed = false;

-// Gather all the conditional ops in our module. We do this ahead of time so
-// we don't have to worry about mutating the lists of computations or
-// instructions as we iterate.
+if (pursue_full_conditional_code_motion_) {
std::vector<HloInstruction*> conditional_ops;
for (auto* comp : module->MakeComputationPostOrder()) {
for (auto* instr : comp->MakeInstructionPostOrder()) {
@ -464,13 +678,44 @@ StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
}

for (HloInstruction* conditional_op : conditional_ops) {
-TF_ASSIGN_OR_RETURN(bool result, MergeIdenticalElements(
-conditional_op, is_layout_sensitive_));
+TF_ASSIGN_OR_RETURN(
+bool result,
+MergeIdenticalElements(conditional_op, is_layout_sensitive_));
changed |= result;
}

if (changed) {
HloPassPipeline subpipeline("after_conditional_code_motion");
+subpipeline.AddPass<HloDCE>();
+subpipeline.AddPass<TupleSimplifier>();
+subpipeline.AddPass<HloDCE>();
+TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));
+changed |= cleanup_changed;
+}
+}

+// handling convert rematerialization/hoisting
+{
+std::vector<HloInstruction*> conditional_ops;
+for (auto* comp : module->MakeComputationPostOrder()) {
+for (auto* instr : comp->MakeInstructionPostOrder()) {
+if (instr->opcode() == HloOpcode::kConditional) {
+conditional_ops.push_back(instr);
+}
+}
+}
+for (HloInstruction* conditional_op : conditional_ops) {
+TF_ASSIGN_OR_RETURN(
+bool convert_result,
+ConvertSpecialMove(conditional_op, is_layout_sensitive_));
+changed |= convert_result;
+}
+}

+if (changed) {
+HloPassPipeline subpipeline(
+"after_conditional_code_motion_after_convert_hoisting");
+subpipeline.AddPass<HloDCE>();
subpipeline.AddPass<TupleSimplifier>();
subpipeline.AddPass<HloDCE>();
TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));
@ -23,7 +23,11 @@ limitations under the License.

namespace xla {

-// HLO pass that moves identical ops out of conditional.
+// ConditionalCodeMotion specializes in hoisting/rematerializing
+// unconditional converts in the default mode.
+// When pursue_full_conditional_code_motion_ is set to true, the
+// full HLO pass moves identical ops out of a conditional in addition to moving
+// converts.
// - The definition of identical are the shape of the operands are identical
// and their properties are identical.
// - Currently, only some types of instructions is supported.
@ -35,13 +39,18 @@ class ConditionalCodeMotion : public HloModulePass {
public:
// If is_layout_sensitive is true, then the hoist process preserves layout
// during identical comparison. Otherwise, layout is ignored.
-explicit ConditionalCodeMotion(bool is_layout_sensitive = true)
-: is_layout_sensitive_(is_layout_sensitive) {}
+explicit ConditionalCodeMotion(
+bool is_layout_sensitive = true,
+bool pursue_full_conditional_code_motion = false)
+: is_layout_sensitive_(is_layout_sensitive),
+pursue_full_conditional_code_motion_(
+pursue_full_conditional_code_motion) {}
absl::string_view name() const override { return "conditional-code-motion"; }
StatusOr<bool> Run(HloModule* module) override;

private:
const bool is_layout_sensitive_;
+const bool pursue_full_conditional_code_motion_;
};

} // namespace xla
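How the extended constructor is meant to be driven (the tests below use the same pattern; module stands for any HloModule*):

// With the defaults, only convert hoisting/rematerialization runs.
ConditionalCodeMotion convert_only(/*is_layout_sensitive=*/true);

// Opting in to full code motion also hoists identical ops out of branches.
ConditionalCodeMotion full(/*is_layout_sensitive=*/true,
                           /*pursue_full_conditional_code_motion=*/true);
TF_ASSIGN_OR_RETURN(bool changed, full.Run(module));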
@ -38,7 +38,86 @@ namespace {
using ConditionalCodeMotionTest = HloTestBase;
namespace op = xla::testing::opcode_matchers;

-TEST_F(ConditionalCodeMotionTest, DoNotMoveConvertOut) {
+TEST_F(ConditionalCodeMotionTest, MoveSubsetTupleOut) {
+absl::string_view hlo_string =
+R"(
+HloModule RemoveDotOpOut

+on_true {
+%arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
+%get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
+%reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
+%convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.8493)
+ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(%convert.2894, %reshape.8493)
+}

+on_false {
+%arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
+%get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
+%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
+%add = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717)
+%convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
+ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(%convert.3604, %add)
+}

+ENTRY main {
+pred.1 = pred[] parameter(0)
+arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
+arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
+conditional = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
+get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
+get-first-index.2 = f32[2,512,364]{2,1,0} get-tuple-element(conditional), index=1
+ROOT result = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(get-first-index, get-first-index.2)
+}
+)";
+auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
+ConditionalCodeMotion pass(true, true);
+ASSERT_TRUE(pass.Run(&*module).ValueOrDie());

+HloInstruction* root = module->entry_computation()->root_instruction();
+EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(), op::GetTupleElement())));
+}

+TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditionalRoot) {
+absl::string_view hlo_string =
+R"(
+HloModule RemoveDotOpOut

+on_true {
+%arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
+%get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
+%reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
+%add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.8493, f32[2,512,364]{2,1,0} %reshape.8493)
+%convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %add.8493)
+ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894)
+}

+on_false {
+%arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
+%get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
+%reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
+%add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717)
+%sub.8493 = f32[2,512,364]{2,1,0} subtract(f32[2,512,364]{2,1,0} %add.8493, f32[2,512,364]{2,1,0} %reshape.9717)
+%convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
+ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604)
+}

+ENTRY main {
+pred.1 = pred[] parameter(0)
+arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
+arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
+ROOT conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
+}
+)";
+auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
+ConditionalCodeMotion pass(true, true);
+ASSERT_TRUE(pass.Run(&*module).ValueOrDie());

+HloInstruction* root = module->entry_computation()->root_instruction();
+EXPECT_THAT(root, AllOf(op::Tuple(op::Convert())));
+}

+TEST_F(ConditionalCodeMotionTest, MoveConvertOut) {
absl::string_view hlo_string =
R"(
HloModule RemoveDotOpOut
@ -65,12 +144,16 @@ ENTRY main {
arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
-ROOT result = (bf16[2,512,364]{2,1,0}) tuple(get-first-index)
+add.1 = bf16[2,512,364]{2,1,0} add(bf16[2,512,364]{2,1,0} get-first-index, bf16[2,512,364]{2,1,0} get-first-index)
+ROOT result = (bf16[2,512,364]{2,1,0}) tuple(add.1)
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
-ConditionalCodeMotion pass;
-ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
+ConditionalCodeMotion pass(true, true);
+ASSERT_TRUE(pass.Run(&*module).ValueOrDie());

+HloInstruction* root = module->entry_computation()->root_instruction();
+EXPECT_THAT(root, AllOf(op::Tuple(op::Add(op::Convert(), op::Convert()))));
}

TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) {
@ -123,7 +206,7 @@ ENTRY main {
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
-ConditionalCodeMotion pass;
+ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());

const HloInstruction* conditional =
@ -181,7 +264,7 @@ ENTRY main {
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
-ConditionalCodeMotion pass;
+ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
@ -245,7 +328,7 @@ ENTRY main {
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
-ConditionalCodeMotion pass;
+ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());

const HloInstruction* conditional =
@ -317,7 +400,7 @@ ENTRY main {
)";

auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
-ConditionalCodeMotion pass;
+ConditionalCodeMotion pass(true, true);
ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
}

@ -390,7 +473,7 @@ ENTRY main {
}
)";
auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
-ConditionalCodeMotion pass;
+ConditionalCodeMotion pass(true, true);
ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
const HloInstruction* conditional =
FindInstruction(module.get(), "conditional");
@ -226,6 +226,11 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) {
dims_to_keep.push_back(dim);
}
}

+// We support fast codegen for three cases:
+// 1) Row reduction: (K, R)
+// 2) Column reduction: (K, R, K)
+// 3) "Batched" row reduction: (R, K, R)
if (!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
dims_to_keep) &&
!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
@ -77,8 +77,6 @@ class KernelThunk : public Thunk {
// Will be set by IrEmitterUnnested.
LaunchDimensions launch_dimensions_;

-// Describes how to load this kernel. ExecuteOnStream reuses this loader
-// specification for all executions.
mutable tensorflow::mutex mutex_;

// Loaded kernels for each `StreamExecutor`. Requires pointer stability of
@ -3908,6 +3908,10 @@ const string& HloInstruction::outfeed_config() const {
return Cast<HloOutfeedInstruction>(this)->outfeed_config();
}

+void HloInstruction::set_outfeed_config(const string& config) {
+return Cast<HloOutfeedInstruction>(this)->set_outfeed_config(config);
+}

const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
return Cast<HloCollectiveInstruction>(this)->replica_groups();
}
@ -1755,6 +1755,9 @@ class HloInstruction {
// Returns the config for the Outfeed instruction.
const string& outfeed_config() const;

+// Delegates to HloOutfeedInstruction::set_outfeed_config.
+void set_outfeed_config(const string& config);

// Returns the shape for the Outfeed instruction.
const Shape& outfeed_shape() const;

@ -1141,6 +1141,7 @@ class HloOutfeedInstruction : public HloInstruction {
const Shape& outfeed_shape() const { return outfeed_shape_; }
// Returns the config for the Outfeed instruction.
const string& outfeed_config() const { return outfeed_config_; }
+void set_outfeed_config(const string& config) { outfeed_config_ = config; }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;

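The setter added above simply mirrors the existing getter and forwards to HloOutfeedInstruction. A small sketch of rewriting an outfeed's config in place (outfeed is an illustrative HloInstruction* and not part of the change):

if (outfeed->opcode() == HloOpcode::kOutfeed) {
  std::string config = outfeed->outfeed_config();
  // ... adjust the serialized config as needed ...
  outfeed->set_outfeed_config(config);
}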
@ -29,36 +29,78 @@ StatusOr<bool> MemorySpacePropagation::Run(HloModule* module) {
// Propagate the operand subshapes.
for (int operand_idx = 0; operand_idx < instruction->operand_count();
++operand_idx) {
-modified |=
-PropagateSubshapes(instruction->operand(operand_idx)->shape(),
-instruction->fused_parameter(operand_idx));
+for (const ShapeUtil::IndexedShape& indexed_shape :
+ShapeUtil::GetLeafShapes(
+instruction->operand(operand_idx)->shape())) {
+int64 memory_space = indexed_shape.shape.layout().memory_space();
+modified |= Propagate(indexed_shape.index,
+instruction->fused_parameter(operand_idx),
+memory_space);
+}
}

// Propagate output subshapes.
-modified |= PropagateSubshapes(instruction->shape(),
-instruction->fused_expression_root());
+for (const ShapeUtil::IndexedShape& indexed_shape :
+ShapeUtil::GetLeafShapes(instruction->shape())) {
+int64 memory_space = indexed_shape.shape.layout().memory_space();
+modified |=
+Propagate(indexed_shape.index,
+instruction->fused_expression_root(), memory_space);
+}
}
}
}
return modified;
}

-bool MemorySpacePropagation::PropagateSubshapes(
-const Shape& caller_shape, const HloInstruction* callee_instruction) const {
+bool MemorySpacePropagation::Propagate(ShapeIndexView index,
+const HloInstruction* callee_instruction,
+int64 memory_space) const {
bool modified = false;
-for (const ShapeUtil::IndexedShape& indexed_shape :
-ShapeUtil::GetLeafShapes(caller_shape)) {
-int64 memory_space = indexed_shape.shape.layout().memory_space();
const HloValue& value = dataflow_analysis_->GetUniqueValueAt(
-callee_instruction, indexed_shape.index);
+callee_instruction, index.ToShapeIndex());

for (const HloPosition& position : value.positions()) {
-Shape* shape = ShapeUtil::GetMutableSubshape(
-position.instruction->mutable_shape(), position.index);
-if (shape->layout().memory_space() != memory_space) {
+HloInstruction* instruction = position.instruction;
+Shape* shape = ShapeUtil::GetMutableSubshape(instruction->mutable_shape(),
+position.index);
+if (shape->layout().memory_space() == memory_space) {
+continue;
+}
shape->mutable_layout()->set_memory_space(memory_space);
modified = true;

+// For fusion outputs, propagate the memory space to the fusion root.
+if (instruction->opcode() == HloOpcode::kFusion) {
+Propagate(position.index, instruction->fused_expression_root(),
+memory_space);
}

+const HloInstruction* parent_fusion =
+instruction->parent()->FusionInstruction();
+// For nested fusion roots, pop one level up and propagate the memory space
+// to the output of the calling fusion instruction.
+if (instruction == instruction->parent()->root_instruction() &&
+parent_fusion->parent()->IsFusionComputation()) {
+Propagate(position.index, parent_fusion, memory_space);
+}

+// For nested fusion parameters, pop one level up and propagate the memory
+// space to the operand of the calling fusion instruction.
+if (instruction->opcode() == HloOpcode::kParameter &&
+parent_fusion->parent()->IsFusionComputation()) {
+const HloInstruction* fusion_operand =
+parent_fusion->operand(instruction->parameter_number());
+Propagate(position.index, fusion_operand, memory_space);
+}
+}

+for (const HloUse& use : value.uses()) {
+// For fusion uses, propagate the memory space to the fusion parameter.
+if (use.instruction->opcode() == HloOpcode::kFusion) {
+modified |= Propagate(
+use.operand_index,
+use.instruction->fused_parameter(use.operand_number), memory_space);
}
}
return modified;
@ -31,12 +31,11 @@ class MemorySpacePropagation : public HloModulePass {
StatusOr<bool> Run(HloModule* module) override;

private:
-// Given the caller shape (operand or output) and its corresponding
-// insturction in the fused computation (parameter or root), propagates the
-// memory space to all the subshapes in the callee side. Returns true if the
-// module is modified.
-bool PropagateSubshapes(const Shape& caller_shape,
-const HloInstruction* callee_instruction) const;
+// Given the shape index (operand or output) and its corresponding instruction
+// in the fused computation (parameter or root), propagates the memory space
+// in the callee side. Returns true if the module is modified.
+bool Propagate(ShapeIndexView index, const HloInstruction* callee_instruction,
+int64 memory_space) const;

std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
};
@ -199,5 +199,153 @@ TEST_F(MemorySpacePropagationTest, TupleOutput) {
EXPECT_EQ(module->Hash(), ref->Hash());
}

+TEST_F(MemorySpacePropagationTest, NestedInputFusion) {
+// Tests propagating the memory space to nested fusions on the input side.
+absl::string_view hlo_string = R"(
+HloModule NestedFusion

+%bitcast_fusion {
+%bf_param = s32[3,2]{0,1:T(128)} parameter(0)
+ROOT %bitcast = s32[6]{0:T(128)} bitcast(%bf_param)
+}

+%fused_computation {
+%param_1.3 = s32[1]{0:T(128)} parameter(1)
+%constant.2 = s32[]{:T(128)} constant(-2147483648)
+%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
+%param_2.3 = s32[5]{0:T(128)} parameter(2)
+%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
+%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
+%param_0.1 = s32[3,2]{0,1:T(128)} parameter(0)
+%fusion.1 = s32[6]{0:T(128)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion
+ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %fusion.1)
+}

+ENTRY %entry {
+%param0 = s32[3,2]{0,1:T(128)} parameter(0)
+%param1 = s32[1]{0:T(128)} parameter(1)
+%param2 = s32[5]{0:T(128)} parameter(2)
+%arg0 = s32[3,2]{0,1:T(128)S(1)} copy(%param0)
+%arg1 = s32[1]{0:T(128)} copy(%param1)
+%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
+%fusion = s32[6]{0:T(128)S(1)} fusion(s32[3,2]{0,1:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
+ROOT %root = s32[6]{0:T(128)} copy(%fusion)
+}
+)";
+absl::string_view expected_hlo_string = R"(
+HloModule NestedFusion

+%bitcast_fusion {
+%bf_param = s32[3,2]{0,1:T(128)S(1)} parameter(0)
+ROOT %bitcast = s32[6]{0:T(128)S(1)} bitcast(%bf_param)
+}

+%fused_computation {
+%param_1.3 = s32[1]{0:T(128)} parameter(1)
+%constant.2 = s32[]{:T(128)} constant(-2147483648)
+%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
+%param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
+%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)S(1)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
+%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
+%param_0.1 = s32[3,2]{0,1:T(128)S(1)} parameter(0)
+%fusion.1 = s32[6]{0:T(128)S(1)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion
+ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %fusion.1)
+}

+ENTRY %entry {
+%param0 = s32[3,2]{0,1:T(128)} parameter(0)
+%param1 = s32[1]{0:T(128)} parameter(1)
+%param2 = s32[5]{0:T(128)} parameter(2)
+%arg0 = s32[3,2]{0,1:T(128)S(1)} copy(%param0)
+%arg1 = s32[1]{0:T(128)} copy(%param1)
+%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
+%fusion = s32[6]{0:T(128)S(1)} fusion(s32[3,2]{0,1:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
+ROOT %root = s32[6]{0:T(128)} copy(%fusion)
+}
+)";
+TF_ASSERT_OK_AND_ASSIGN(auto module,
+ParseAndReturnUnverifiedModule(hlo_string));
+MemorySpacePropagation memory_space_propagation;
+EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
+TF_EXPECT_OK(Verify(module.get()));
+TF_ASSERT_OK_AND_ASSIGN(auto ref,
+ParseAndReturnVerifiedModule(expected_hlo_string));
+EXPECT_EQ(module->Hash(), ref->Hash());
+}

+TEST_F(MemorySpacePropagationTest, NestedOutputFusion) {
+// Tests propagating the memory space to nested fusions on the output side.
+absl::string_view hlo_string = R"(
+HloModule NestedFusion

+%bitcast_fusion {
+%bf_param = s32[6]{0:T(128)} parameter(0)
+ROOT %bitcast = s32[3,2]{0,1:T(128)} bitcast(%bf_param)
+}

+%fused_computation {
+%param_1.3 = s32[1]{0:T(128)} parameter(1)
+%constant.2 = s32[]{:T(128)} constant(-2147483648)
+%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
+%param_2.3 = s32[5]{0:T(128)} parameter(2)
+%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
+%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
+%param_0.1 = s32[6]{0:T(128)} parameter(0)
+%add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
+ROOT %fusion.1 = s32[3,2]{0,1:T(128)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion
+}

+ENTRY %entry {
+%param0 = s32[6]{0:T(128)} parameter(0)
+%param1 = s32[1]{0:T(128)} parameter(1)
+%param2 = s32[5]{0:T(128)} parameter(2)
+%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
+%arg1 = s32[1]{0:T(128)} copy(%param1)
+%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
+%fusion = s32[3,2]{0,1:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
+ROOT %root = s32[3,2]{0,1:T(128)} copy(%fusion)
+}
+)";
+absl::string_view expected_hlo_string = R"(
+HloModule NestedFusion

+%bitcast_fusion {
+%bf_param = s32[6]{0:T(128)S(1)} parameter(0)
+ROOT %bitcast = s32[3,2]{0,1:T(128)S(1)} bitcast(%bf_param)
+}

+%fused_computation {
+%param_1.3 = s32[1]{0:T(128)} parameter(1)
+%constant.2 = s32[]{:T(128)} constant(-2147483648)
+%pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
+%param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
+%pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)S(1)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
+%maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
+%param_0.1 = s32[6]{0:T(128)S(1)} parameter(0)
+%add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1)
+ROOT %fusion.1 = s32[3,2]{0,1:T(128)S(1)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion
+}

+ENTRY %entry {
+%param0 = s32[6]{0:T(128)} parameter(0)
+%param1 = s32[1]{0:T(128)} parameter(1)
+%param2 = s32[5]{0:T(128)} parameter(2)
+%arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
+%arg1 = s32[1]{0:T(128)} copy(%param1)
+%arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
+%fusion = s32[3,2]{0,1:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
+ROOT %root = s32[3,2]{0,1:T(128)} copy(%fusion)
+}
+)";
+TF_ASSERT_OK_AND_ASSIGN(auto module,
+ParseAndReturnUnverifiedModule(hlo_string));
+MemorySpacePropagation memory_space_propagation;
+EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
+TF_EXPECT_OK(Verify(module.get()));
+TF_ASSERT_OK_AND_ASSIGN(auto ref,
+ParseAndReturnVerifiedModule(expected_hlo_string));
+EXPECT_EQ(module->Hash(), ref->Hash());
+}

} // namespace
} // namespace xla
@ -552,7 +552,7 @@ class LowerToNVVMPass
// TODO(csigg): Remove once we support replacing non-root ops.
target.addLegalOp<::mlir::gpu::GPUModuleOp, ::mlir::gpu::ModuleEndOp,
::mlir::gpu::YieldOp>();
-if (failed(mlir::applyFullConversion(m, target, patterns, &converter))) {
+if (failed(mlir::applyFullConversion(m, target, patterns))) {
signalPassFailure();
}
}
@@ -52,16 +52,26 @@ cc_library(
     name = "test_macros_header",
     testonly = True,
     hdrs = ["test_macros.h"],
-    deps = [
-        "//tensorflow/compiler/xla:types",
-        "//tensorflow/core:test",
-        "@com_google_absl//absl/strings",
-    ],
 )
 
 # Generate a test_macros_${BACKEND} library per backend with the proper copts.
 generate_backend_test_macros()
 
+cc_library(
+    name = "manifest_checking_test",
+    testonly = True,
+    srcs = ["manifest_checking_test.cc"],
+    hdrs = ["manifest_checking_test.h"],
+    deps = [
+        ":test_macros_header",
+        "//tensorflow/core:regexp_internal",
+        "//tensorflow/core:test",
+        "//tensorflow/core/platform:logging",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 cc_library(
     name = "test_utils",
     srcs = ["test_utils.cc"],

@@ -136,6 +146,7 @@ cc_library(
     hdrs = ["hlo_test_base.h"],
     deps = [
         ":literal_test_util",
+        ":manifest_checking_test",
         ":test_utils",
         ":verified_hlo_module",
         "//tensorflow/compiler/xla:debug_options_flags",

@@ -193,6 +204,7 @@ cc_library(
     srcs = ["client_library_test_base.cc"],
     hdrs = ["client_library_test_base.h"],
     deps = [
+        ":manifest_checking_test",
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:array3d",
        "//tensorflow/compiler/xla:array4d",

@@ -273,6 +285,7 @@ cc_library(
     hdrs = ["local_client_test_base.h"],
     deps = [
         ":client_library_test_base",
+        ":manifest_checking_test",
         ":verified_hlo_module",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",

@@ -266,11 +266,6 @@ def generate_backend_test_macros(backends = []):
            "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
        ],
        deps = [
-           "@com_google_absl//absl/container:flat_hash_map",
-           "@com_google_absl//absl/strings",
-           "//tensorflow/compiler/xla:types",
-           "//tensorflow/core:lib",
-           "//tensorflow/core:regexp_internal",
-           "//tensorflow/core:test",
+           "//tensorflow/core/platform:logging",
        ],
    )

@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
 #include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/bitmap.h"

@@ -62,7 +63,7 @@ std::vector<TestCase> ExpandUseBfloat16(
 }
 
 // A client library test establishes an in-process XLA client connection.
-class ClientLibraryTestBase : public ::testing::Test {
+class ClientLibraryTestBase : public ManifestCheckingTest {
  protected:
   explicit ClientLibraryTestBase(se::Platform* platform = nullptr);
 

@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_layout.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
 #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"

@@ -67,7 +68,7 @@ namespace xla {
 // )
 //
 // For a more detailed example, see "../tests/sample_text_test.cc".
-class HloTestBase : public ::testing::Test {
+class HloTestBase : public ManifestCheckingTest {
  public:
   // Creates a new HLO module for a test. The module created will have
   // TestName() for its name; it will also automatically populate its debug

@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
 #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/mutex.h"

@@ -75,7 +76,7 @@ class TestAllocator : public se::StreamExecutorMemoryAllocator {
 };
 
 // A base class for tests which exercise the LocalClient interface.
-class LocalClientTestBase : public ::testing::Test {
+class LocalClientTestBase : public ManifestCheckingTest {
  protected:
   struct EigenThreadPoolWrapper;
   explicit LocalClientTestBase(se::Platform* platform = nullptr);
tensorflow/compiler/xla/tests/manifest_checking_test.cc (new file, 129 lines)
@@ -0,0 +1,129 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
+
+#include <fstream>
+#include <iterator>
+#include <string>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace xla {
+
+namespace {
+
+// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is
+// disabled - a sequence of regexps.
+using ManifestT = absl::flat_hash_map<std::string, std::vector<std::string>>;
+
+ManifestT ReadManifest() {
+  ManifestT manifest;
+
+  absl::string_view path = absl::NullSafeStringView(kDisabledManifestPath);
+  if (path.empty()) {
+    return manifest;
+  }
+
+  // Note: parens are required to disambiguate vs function decl.
+  std::ifstream file_stream((std::string(path)));
+  std::string contents((std::istreambuf_iterator<char>(file_stream)),
+                       std::istreambuf_iterator<char>());
+
+  std::vector<std::string> lines = absl::StrSplit(contents, '\n');
+  for (std::string& line : lines) {
+    auto comment = line.find("//");
+    if (comment != std::string::npos) {
+      line = line.substr(0, comment);
+    }
+    if (line.empty()) {
+      continue;
+    }
+    absl::StripTrailingAsciiWhitespace(&line);
+    std::vector<std::string> pieces = absl::StrSplit(line, ' ');
+    CHECK_GE(pieces.size(), 1);
+    auto& platforms = manifest[pieces[0]];
+    for (size_t i = 1; i < pieces.size(); ++i) {
+      platforms.push_back(pieces[i]);
+    }
+  }
+  return manifest;
+}
+
+}  // namespace
+
+void ManifestCheckingTest::SetUp() {
+  const testing::TestInfo* test_info =
+      testing::UnitTest::GetInstance()->current_test_info();
+  absl::string_view test_case_name = test_info->test_suite_name();
+  absl::string_view test_name = test_info->name();
+  VLOG(1) << "test_case_name: " << test_case_name;
+  VLOG(1) << "test_name: " << test_name;
+
+  // Remove the type suffix from the test case name.
+  if (const char* type_param = test_info->type_param()) {
+    VLOG(1) << "type_param: " << type_param;
+    size_t last_slash = test_case_name.rfind('/');
+    test_case_name = test_case_name.substr(0, last_slash);
+    VLOG(1) << "test_case_name: " << test_case_name;
+  }
+
+  // Remove the test instantiation name if it is present.
+  auto first_slash = test_case_name.find('/');
+  if (first_slash != test_case_name.npos) {
+    test_case_name.remove_prefix(first_slash + 1);
+    VLOG(1) << "test_case_name: " << test_case_name;
+  }
+
+  ManifestT manifest = ReadManifest();
+
+  // If the test name ends with a slash followed by one or more characters,
+  // strip that off.
+  auto last_slash = test_name.rfind('/');
+  if (last_slash != test_name.npos) {
+    test_name = test_name.substr(0, last_slash);
+    VLOG(1) << "test_name: " << test_name;
+  }
+
+  // First try full match: test_case_name.test_name
+  // If that fails, try to find just the test_case_name; this would disable all
+  // tests in the test case.
+  auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
+  if (it == manifest.end()) {
+    it = manifest.find(test_case_name);
+    if (it == manifest.end()) {
+      return;
+    }
+  }
+
+  // Expect a full match vs. one of the platform regexps to disable the test.
+  const std::vector<std::string>& disabled_platforms = it->second;
+  auto platform_string = kTestPlatform;
+  for (const auto& s : disabled_platforms) {
+    if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
+      GTEST_SKIP();
+      return;
+    }
+  }
+
+  // We didn't hit in the disabled manifest entries, so don't disable it.
+}
+
+}  // namespace xla
tensorflow/compiler/xla/tests/manifest_checking_test.h (new file, 35 lines)
@@ -0,0 +1,35 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
+
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+
+// This class allows us to intercept the test name and use an arbitrary
+// heuristic to decide whether the test case should be disabled. We
+// determine whether the test case should be disabled by resolving the (test
+// case name, test name) in a manifest file.
+class ManifestCheckingTest : public ::testing::Test {
+ protected:
+  // This method runs before each test runs.
+  void SetUp() override;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
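Side note (not part of the change): judging from ReadManifest() above, the disabled-tests manifest pointed to by XLA_DISABLED_MANIFEST is a plain-text file with one entry per line -- a TestCase.TestName key followed by space-separated platform regexps, with "//" starting a comment. The test and platform names below are made up purely for illustration:

    // Disable one test wherever the platform matches the regexp.
    MyArrayTest.SomeSlowCase gpu.*
    // Disable an entire test case on the interpreter backend.
    MyDotTest interpreter

When ManifestCheckingTest::SetUp() finds an entry whose platform regexp fully matches kTestPlatform, it calls GTEST_SKIP(), so the test is skipped at runtime rather than renamed at registration time.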
@@ -15,93 +15,18 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 
-#include <fstream>
-#include <streambuf>
-#include <string>
-
-#include "absl/container/flat_hash_map.h"
-#include "absl/strings/str_cat.h"
-#include "absl/strings/str_split.h"
 #include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/regexp.h"
 
 namespace xla {
-namespace {
 
-// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is
-// disabled - a sequence of regexps.
-using ManifestT = absl::flat_hash_map<string, std::vector<string>>;
-
-ManifestT ReadManifest() {
-  ManifestT manifest;
-
-  string path = XLA_DISABLED_MANIFEST;
-  if (path.empty()) {
-    return manifest;
-  }
-
-  std::ifstream file_stream(path);
-  // Note: parens are required to disambiguate vs function decl.
-  string contents((std::istreambuf_iterator<char>(file_stream)),
-                  std::istreambuf_iterator<char>());
-
-  std::vector<string> lines = absl::StrSplit(contents, '\n');
-  for (string& line : lines) {
-    auto comment = line.find("//");
-    if (comment != string::npos) {
-      line = line.substr(0, comment);
-    }
-    if (line.empty()) {
-      continue;
-    }
-    absl::StripTrailingAsciiWhitespace(&line);
-    std::vector<string> pieces = absl::StrSplit(line, ' ');
-    CHECK_GE(pieces.size(), 1);
-    auto& platforms = manifest[pieces[0]];
-    for (int64 i = 1; i < pieces.size(); ++i) {
-      platforms.push_back(pieces[i]);
-    }
-  }
-  return manifest;
-}
-
-}  // namespace
-
-std::string PrependDisabledIfIndicated(absl::string_view test_case_name,
-                                       absl::string_view test_name) {
-  ManifestT manifest = ReadManifest();
-
-  // If the test name ends with a slash followed by one or more digits, strip
-  // that off; this is just a shard number, and matching on this would be
-  // unstable even if someone wanted to do it.
-  static LazyRE2 shard_num_pattern = {R"(/\d+$)"};
-  absl::string_view suffix;
-  if (RE2::PartialMatch(test_name, *shard_num_pattern, &suffix)) {
-    test_name.remove_suffix(suffix.size());
-  }
-
-  // First try full match: test_case_name.test_name
-  // If that fails, try to find just the test_case_name; this would disable all
-  // tests in the test case.
-  auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
-  if (it == manifest.end()) {
-    it = manifest.find(test_case_name);
-    if (it == manifest.end()) {
-      return std::string(test_name);
-    }
-  }
-
-  // Expect a full match vs. one of the platform regexps to disable the test.
-  const std::vector<string>& disabled_platforms = it->second;
-  string platform_string = XLA_PLATFORM;
-  for (const auto& s : disabled_platforms) {
-    if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
-      return absl::StrCat("DISABLED_", test_name);
-    }
-  }
-
-  // We didn't hit in the disabled manifest entries, so don't disable it.
-  return std::string(test_name);
-}
+static bool InitModule() {
+  kDisabledManifestPath = XLA_DISABLED_MANIFEST;
+  VLOG(1) << "kDisabledManifestPath: " << kDisabledManifestPath;
+  kTestPlatform = XLA_PLATFORM;
+  VLOG(1) << "kTestPlatform: " << kTestPlatform;
+  return false;
+}
+
+static bool module_initialized = InitModule();
 
 }  // namespace xla
@@ -28,12 +28,6 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
 #define TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
 
-#include <string>
-
-#include "absl/strings/string_view.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/platform/test.h"
-
 #define DISABLED_ON_CPU(X) X
 #define DISABLED_ON_GPU(X) X
 #define DISABLED_ON_GPU_ROCM(X) X
@@ -79,117 +73,15 @@ limitations under the License.
 
 namespace xla {
 
-// Reads a disabled manifest file to resolve whether test cases should be
-// disabled on a particular platform. For a test that should be disabled,
-// returns DISABLED_ prepended to its name; otherwise returns the test name
-// unmodified.
-std::string PrependDisabledIfIndicated(absl::string_view test_case_name,
-                                       absl::string_view test_name);
+inline const char *kDisabledManifestPath = nullptr;
+inline const char *kTestPlatform = nullptr;
 
 }  // namespace xla
 
-// This is the internal "gtest" class instantiation -- it is identical to the
-// GTEST_TEST_ macro, except that we intercept the test name for potential
-// modification by PrependDisabledIfIndicated. That file can use an arbitrary
-// heuristic to decide whether the test case should be disabled, and we
-// determine whether the test case should be disabled by resolving the (test
-// case name, test name) in a manifest file.
-#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class) \
-  class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \
-      : public parent_class { \
-   public: \
-    GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \
- \
-   private: \
-    virtual void TestBody(); \
-    static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \
-    GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \
-                                                           test_name)); \
-  }; \
- \
-  ::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \
-                                                    test_name)::test_info_ = \
-      ::testing::RegisterTest( \
-          #test_case_name, \
-          ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \
-              .c_str(), \
-          nullptr, nullptr, __FILE__, __LINE__, []() -> parent_class* { \
-            return new GTEST_TEST_CLASS_NAME_(test_case_name, test_name)(); \
-          }); \
-  void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()
-
-// This is identical to the TEST_F macro from "gtest", but it potentially
-// disables the test based on an external manifest file, DISABLED_MANIFEST.
-//
-// Per usual, you can see what tests are available via --gunit_list_tests and
-// choose to run tests that have been disabled via the manifest via
-// --gunit_also_run_disabled_tests.
-#define XLA_TEST_F(test_fixture, test_name) \
-  XLA_GTEST_TEST_(test_fixture, test_name, test_fixture)
+#define XLA_TEST_F(test_fixture, test_name) TEST_F(test_fixture, test_name)
 
-// Likewise, this is identical to the TEST_P macro from "gtest", but
-// potentially disables the test based on the DISABLED_MANIFEST file.
-//
-// We have to wrap this in an outer layer so that any DISABLED_ON_* macros will
-// be properly expanded before the stringification occurs.
-#define XLA_TEST_P_IMPL_(test_case_name, test_name) \
-  class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \
-      : public test_case_name { \
-   public: \
-    GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \
-    virtual void TestBody(); \
- \
-   private: \
-    static int AddToRegistry() { \
-      ::testing::UnitTest::GetInstance() \
-          ->parameterized_test_registry() \
-          .GetTestCasePatternHolder<test_case_name>( \
-              #test_case_name, \
-              ::testing::internal::CodeLocation(__FILE__, __LINE__)) \
-          ->AddTestPattern( \
-              #test_case_name, \
-              ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \
-                  .c_str(), \
-              new ::testing::internal::TestMetaFactory<GTEST_TEST_CLASS_NAME_( \
-                  test_case_name, test_name)>()); \
-      return 0; \
-    } \
-    static int gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_; \
-    GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \
-                                                           test_name)); \
-  }; \
-  int GTEST_TEST_CLASS_NAME_(test_case_name, \
-                             test_name)::gtest_registering_dummy_ = \
-      GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::AddToRegistry(); \
-  void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()
+#define XLA_TEST_P(test_case_name, test_name) TEST_P(test_case_name, test_name)
 
-#define XLA_TEST_P(test_case_name, test_name) \
-  XLA_TEST_P_IMPL_(test_case_name, test_name)
-
-// This is identical to the TEST_F macro from "gtest", but it potentially
-// disables the test based on an external manifest file, DISABLED_MANIFEST.
-#define XLA_TYPED_TEST(CaseName, TestName) \
-  template <typename gtest_TypeParam_> \
-  class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \
-      : public CaseName<gtest_TypeParam_> { \
-   private: \
-    typedef CaseName<gtest_TypeParam_> TestFixture; \
-    typedef gtest_TypeParam_ TypeParam; \
-    virtual void TestBody(); \
-  }; \
-  bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ = \
-      ::testing::internal::TypeParameterizedTest< \
-          CaseName, \
-          ::testing::internal::TemplateSel<GTEST_TEST_CLASS_NAME_(CaseName, \
-                                                                  TestName)>, \
-          GTEST_TYPE_PARAMS_(CaseName)>:: \
-          Register( \
-              "", ::testing::internal::CodeLocation(__FILE__, __LINE__), \
-              #CaseName, \
-              ::xla::PrependDisabledIfIndicated(#CaseName, #TestName).c_str(), \
-              0); \
-  template <typename gtest_TypeParam_> \
-  void GTEST_TEST_CLASS_NAME_(CaseName, \
-                              TestName)<gtest_TypeParam_>::TestBody()
+#define XLA_TYPED_TEST(CaseName, TestName) TYPED_TEST(CaseName, TestName)
 
 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
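Side note (not part of the change): with this rewrite, XLA_TEST_F / XLA_TEST_P / XLA_TYPED_TEST forward directly to the stock gtest macros, and manifest-based disabling moves from registration time (prepending DISABLED_ via PrependDisabledIfIndicated) to run time (ManifestCheckingTest::SetUp calling GTEST_SKIP). A hypothetical sketch of how a test is affected, assuming a fixture derived from HloTestBase (which now inherits ManifestCheckingTest):

    // MyHloTest is a made-up fixture name, used only for illustration.
    class MyHloTest : public HloTestBase {};

    XLA_TEST_F(MyHloTest, AddsConstants) {  // now literally TEST_F(MyHloTest, AddsConstants)
      // If the disabled manifest has an entry matching "MyHloTest.AddsConstants"
      // and the current platform, SetUp() skips this body via GTEST_SKIP()
      // instead of the test being renamed with a DISABLED_ prefix.
      EXPECT_EQ(2 + 2, 4);
    }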
@@ -719,6 +719,7 @@ tf_cuda_library(
     visibility = [
         "//tensorflow/core:__pkg__",
         "//tensorflow/core/util:__pkg__",
+        "//tensorflow/security/fuzzing:__subpackages__",
     ],
     deps = [
         ":allocation_description_proto_cc",
@@ -153,16 +153,9 @@ limitations under the License.
 #endif  // defined(IS_MOBILE_PLATFORM) - end of TF_CALL_type defines
 
 // Defines for sets of types.
-
-// TODO(b/111604096): Add uint32 and uint64 to TF_CALL_INTEGRAL_TYPES.
-//
-// The uint32 and uint64 types were introduced in 10/2017 to be used via XLA and
-// thus were not included in TF_CALL_INTEGRAL_TYPES. Including them in
-// TF_CALL_INTEGRAL_TYPES should only happen after evaluating the effect on the
-// TF binary size and performance.
 #define TF_CALL_INTEGRAL_TYPES(m) \
-  TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \
-      TF_CALL_uint8(m) TF_CALL_int8(m)
+  TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_int32(m) \
+      TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)
 
 #define TF_CALL_FLOAT_TYPES(m) \
   TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)

@@ -176,8 +169,8 @@ limitations under the License.
 
 #define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
   TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) \
-      TF_CALL_int64(m) TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) \
-      TF_CALL_int8(m)
+      TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_uint16(m) \
+      TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)
 
 #define TF_CALL_COMPLEX_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
 
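Side note (not part of the change): TF_CALL_INTEGRAL_TYPES now fans out over uint64 and uint32 as well, which is why the explicit per-type uint32/uint64 registrations are deleted from the kernel files below. A hypothetical sketch of how such a macro is consumed (REGISTER_MY_KERNEL is made up for illustration):

    // Each TF_CALL_<type>(m) invokes m(<type>) when that type is compiled in.
    #define REGISTER_MY_KERNEL(T) /* e.g. REGISTER_KERNEL_BUILDER(... <T> ...); */
    TF_CALL_INTEGRAL_TYPES(REGISTER_MY_KERNEL)
    // On a full build this now expands to REGISTER_MY_KERNEL(uint64)
    // REGISTER_MY_KERNEL(int64) REGISTER_MY_KERNEL(uint32) REGISTER_MY_KERNEL(int32)
    // ... REGISTER_MY_KERNEL(int8), so separate TF_CALL_uint32/TF_CALL_uint64
    // invocations are no longer needed.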
@@ -238,11 +238,6 @@ int DataTypeSize(DataType dt) {
     TF_CALL_qint16(CASE);
     TF_CALL_quint16(CASE);
 
-    // uint32 and uint64 aren't included in TF_CALL_POD_TYPES because we
-    // don't want to define kernels for them at this stage to avoid binary
-    // bloat.
-    TF_CALL_uint32(CASE);
-    TF_CALL_uint64(CASE);
     default:
       return 0;
   }

@@ -837,6 +837,7 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
        "ParameterizedTruncatedNormal", "TruncatedNormal", "RandomShuffle",
        "Multinomial", "RandomGamma", "RandomGammaGrad", "RandomPoisson",
        "RandomPoissonV2",
+       // LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
 
        // ReadVariableOp marked as stateful because it consumes DT_RESOURCE,
        // but it can't generate any observable side-effect.

@@ -850,12 +851,7 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
        // the same device_ordinal on the same host.
        "EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
        "EnqueueTPUEmbeddingSparseTensorBatch",
-       "EnqueueTPUEmbeddingRaggedTensorBatch",
-
-       // SaveV2 and RestoreV2 should be allowed to operate in parallel on
-       // multiple hosts.
-       "SaveV2", "RestoreV2"});
-  // LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
+       "EnqueueTPUEmbeddingRaggedTensorBatch"});
   return exemption->contains(op);
 }
 
@@ -4168,6 +4168,25 @@ tf_kernel_library(
     ]),
 )
 
+tf_cuda_cc_test(
+    name = "mlir_generated_op_gpu_tanh_test",
+    size = "small",
+    srcs = if_mlir_generated_gpu_kernels_enabled(["mlir_generated_op_gpu_tanh_test.cc"]),
+    tags = tf_cuda_tests_tags() + ["no_rocm"],
+    deps = [
+        ":cwise_op",
+        ":ops_testutil",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:tensorflow",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/common_runtime:device",
+        "//tensorflow/core/common_runtime:device_factory",
+    ],
+)
+
 tf_kernel_library(
     name = "nextafter_op",
     prefix = "nextafter_op",
@@ -4900,7 +4919,9 @@ tf_kernel_library(
         "topk_op_gpu_double.cu.cc",
         "topk_op_gpu_float.cu.cc",
         "topk_op_gpu_half.cu.cc",
+        "topk_op_gpu_uint64.cu.cc",
         "topk_op_gpu_int64.cu.cc",
+        "topk_op_gpu_uint32.cu.cc",
         "topk_op_gpu_int32.cu.cc",
         "topk_op_gpu_int16.cu.cc",
         "topk_op_gpu_uint16.cu.cc",

@@ -6802,7 +6823,8 @@ filegroup(
         "cwise_op_minimum.cc",
         "cwise_op_mul_1.cc",
         "cwise_op_mul_2.cc",
-        "cwise_op_neg.cc",
+        "cwise_op_neg_1.cc",
+        "cwise_op_neg_2.cc",
         "cwise_op_pow.cc",
         "cwise_op_real.cc",
         "cwise_op_reciprocal.cc",

@@ -8780,7 +8802,8 @@ exports_files([
     "cwise_op_mod.cc",
     "cwise_op_mul_1.cc",
     "cwise_op_mul_2.cc",
-    "cwise_op_neg.cc",
+    "cwise_op_neg_1.cc",
+    "cwise_op_neg_2.cc",
     "cwise_op_not_equal_to_1.cc",
     "cwise_op_not_equal_to_2.cc",
     "cwise_op_round.cc",
@@ -116,8 +116,6 @@ REGISTER(qint8)
 REGISTER(quint16)
 REGISTER(qint16)
 REGISTER(qint32)
-REGISTER(uint32)
-REGISTER(uint64)
 
 #if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \
     !defined(__ANDROID_TYPES_FULL__)

@@ -208,8 +208,6 @@ REGISTER_CONCAT(qint8);
 REGISTER_CONCAT(quint16);
 REGISTER_CONCAT(qint16);
 REGISTER_CONCAT(qint32);
-REGISTER_CONCAT(uint32);
-REGISTER_CONCAT(uint64);
 
 #undef REGISTER_CONCAT
 

@@ -211,7 +211,6 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL);
 // the conversion from uint8 to quint8.
 REGISTER_KERNEL(CPU, quint8);
 REGISTER_KERNEL(CPU, quint16);
-REGISTER_KERNEL(CPU, uint32);
 #undef REGISTER_CPU_KERNEL
 
 #ifdef TENSORFLOW_USE_SYCL

@@ -101,27 +101,21 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH);
 TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH);
 TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH);
 TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_REF_SWITCH);
-REGISTER_CPU_SWITCH(uint64);
 
 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH);
 TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_SWITCH);
 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
 TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
-REGISTER_GPU_SWITCH(uint64);
 TF_CALL_variant(REGISTER_GPU_SWITCH);
-TF_CALL_uint32(REGISTER_GPU_SWITCH);
-TF_CALL_uint32(REGISTER_GPU_REF_SWITCH);
-TF_CALL_bool(REGISTER_GPU_SWITCH);
-TF_CALL_bool(REGISTER_GPU_REF_SWITCH);
 
 #undef REGISTER_CPU_SWITCH
 #undef REGISTER_CPU_REF_SWITCH
 #undef REGISTER_GPU_SWITCH
 #undef REGISTER_GPU_REF_SWITCH
 
-// Special GPU kernels for int32, string & resource handles. Requiring all
-// inputs and outputs to be in host memory.
-// TODO(b/25387198): Also enable int32 in device memory.
+// Special GPU kernels for int32 and string.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
 #define REGISTER_GPU_HOST_KERNEL(type) \
   REGISTER_KERNEL_BUILDER(Name("Switch") \
                               .Device(DEVICE_GPU) \

@@ -151,6 +145,8 @@ TF_CALL_bool(REGISTER_GPU_REF_SWITCH);
 
 REGISTER_GPU_HOST_KERNEL(int32);
 REGISTER_GPU_HOST_REF_KERNEL(int32);
+REGISTER_GPU_HOST_KERNEL(bool);
+REGISTER_GPU_HOST_REF_KERNEL(bool);
 REGISTER_GPU_HOST_KERNEL(tstring);
 REGISTER_GPU_HOST_REF_KERNEL(tstring);
 REGISTER_GPU_HOST_KERNEL(ResourceHandle);

@@ -311,7 +307,6 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_KERNEL);
 TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_KERNEL);
 REGISTER_GPU_KERNEL(bool);
 REGISTER_GPU_REF_KERNEL(bool);
-REGISTER_GPU_KERNEL(uint64);
 TF_CALL_variant(REGISTER_GPU_KERNEL);
 
 #undef REGISTER_GPU_KERNEL
 
@@ -19,8 +19,8 @@ limitations under the License.
 
 namespace tensorflow {
 namespace functor {
-DEFINE_UNARY7(neg, Eigen::half, float, double, int32, int64, complex64,
-              complex128);
+DEFINE_UNARY4(neg, int8, int16, int32, int64);
+DEFINE_UNARY6(neg, Eigen::half, float, double, bfloat16, complex64, complex128);
 }  // namespace functor
 }  // namespace tensorflow
 

@@ -16,8 +16,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/cwise_ops_common.h"
 
 namespace tensorflow {
-REGISTER8(UnaryOp, CPU, "Neg", functor::neg, float, Eigen::half, double, int32,
-          complex64, int64, complex128, bfloat16);
+REGISTER4(UnaryOp, CPU, "Neg", functor::neg, int8, int16, int32, int64);
 
 #ifdef TENSORFLOW_USE_SYCL
 REGISTER3(UnaryOp, SYCL, "Neg", functor::neg, float, double, int64);

@@ -30,8 +29,7 @@ REGISTER_KERNEL_BUILDER(Name("Neg")
 #endif  // TENSORFLOW_USE_SYCL
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-REGISTER6(UnaryOp, GPU, "Neg", functor::neg, float, Eigen::half, double, int64,
-          complex64, complex128);
+REGISTER3(UnaryOp, GPU, "Neg", functor::neg, int8, int16, int64);
 
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
tensorflow/core/kernels/cwise_op_neg_2.cc (new file, 26 lines)
@@ -0,0 +1,26 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER6(UnaryOp, CPU, "Neg", functor::neg, Eigen::half, float, double,
+          bfloat16, complex64, complex128);
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+REGISTER6(UnaryOp, GPU, "Neg", functor::neg, Eigen::half, float, double,
+          bfloat16, complex64, complex128);
+#endif
+}  // namespace tensorflow
@@ -220,8 +220,6 @@ Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
       break;
     TF_CALL_NUMBER_TYPES(CASE);
     TF_CALL_tstring(CASE);
-    TF_CALL_uint32(CASE);
-    TF_CALL_uint64(CASE);
     // TODO(feihugis): figure out how to support variant tensors.
 #undef CASE
     default:

@@ -98,7 +98,6 @@ typedef Eigen::SyclDevice SYCLDevice;
 
 TF_CALL_ALL_TYPES(REGISTER_KERNELS);
 // uint32 not included in ALL_TYPES
-TF_CALL_uint32(REGISTER_KERNELS);
 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
 // quint16 not included in QUANTIZIED_TYPES
 TF_CALL_quint16(REGISTER_KERNELS);

@@ -164,8 +164,6 @@ class DynamicPartitionOp : public DynamicPartitionOp_Shared {
                         DynamicPartitionOp<T>)
 
 TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_PARTITION);
-// For partitioning fingerprints.
-TF_CALL_uint64(REGISTER_DYNAMIC_PARTITION);
 #undef REGISTER_DYNAMIC_PARTITION
 
 }  // namespace tensorflow

@@ -45,6 +45,8 @@ DEFINE_SETZERO_CPU(Eigen::half);
 DEFINE_SETZERO_CPU(bfloat16);
 DEFINE_SETZERO_CPU(float);
 DEFINE_SETZERO_CPU(double);
+DEFINE_SETZERO_CPU(uint32);
+DEFINE_SETZERO_CPU(uint64);
 DEFINE_SETZERO_CPU(uint8);
 DEFINE_SETZERO_CPU(int8);
 DEFINE_SETZERO_CPU(uint16);

@@ -96,6 +98,8 @@ DEFINE_SETONE_CPU(Eigen::half);
 DEFINE_SETONE_CPU(bfloat16);
 DEFINE_SETONE_CPU(float);
 DEFINE_SETONE_CPU(double);
+DEFINE_SETONE_CPU(uint32);
+DEFINE_SETONE_CPU(uint64);
 DEFINE_SETONE_CPU(uint8);
 DEFINE_SETONE_CPU(int8);
 DEFINE_SETONE_CPU(uint16);

@@ -137,7 +141,6 @@ struct FillFunctor<Eigen::ThreadPoolDevice, T> {
 TF_CALL_ALL_TYPES(DEFINE_FILL_CPU);
 DEFINE_FILL_CPU(quint8);
 DEFINE_FILL_CPU(quint16);
-DEFINE_FILL_CPU(uint32);
 #undef DEFINE_FILL_CPU
 
 #ifdef TENSORFLOW_USE_SYCL

@@ -211,8 +211,6 @@ TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
 TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
 TF_CALL_quint16(REGISTER_GATHER_CPU);
 TF_CALL_qint16(REGISTER_GATHER_CPU);
-TF_CALL_uint32(REGISTER_GATHER_CPU);
-TF_CALL_uint64(REGISTER_GATHER_CPU);
 
 #undef REGISTER_GATHER_CPU
 

@@ -122,7 +122,6 @@ REGISTER_SYCL_HOST_KERNEL(bool);
 
 TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
 REGISTER_GPU_KERNEL(Variant);
-TF_CALL_uint32(REGISTER_GPU_KERNEL);
 
 #undef REGISTER_GPU_KERNEL
 
@ -178,6 +178,9 @@ class MklAddNOp : public OpKernel {
|
|||||||
dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_data_format);
|
dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_data_format);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<stream> fwd_cpu_stream;
|
||||||
|
fwd_cpu_stream.reset(CreateStream(ctx, cpu_engine));
|
||||||
|
|
||||||
// Create memory descriptor for MKL-DNN.
|
// Create memory descriptor for MKL-DNN.
|
||||||
// If all input in Tensorflow format, create block memory descriptor,
|
// If all input in Tensorflow format, create block memory descriptor,
|
||||||
// else convert TF format to MKL memory descriptor
|
// else convert TF format to MKL memory descriptor
|
||||||
@ -215,6 +218,7 @@ class MklAddNOp : public OpKernel {
|
|||||||
srcs_pd.push_back(memory::primitive_desc(md, cpu_engine));
|
srcs_pd.push_back(memory::primitive_desc(md, cpu_engine));
|
||||||
#endif
|
#endif
|
||||||
src.SetUsrMem(md, &src_tensor);
|
src.SetUsrMem(md, &src_tensor);
|
||||||
|
src.SetUsrMemDataHandle(&src_tensor, fwd_cpu_stream);
|
||||||
inputs.push_back(src.GetOpMem());
|
inputs.push_back(src.GetOpMem());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -240,11 +244,10 @@ class MklAddNOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
|
AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
|
||||||
output_mkl_shape);
|
output_mkl_shape);
|
||||||
dst.SetUsrMemDataHandle(dst_tensor);
|
dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
|
||||||
|
|
||||||
// Create Sum op, and submit net for execution.
|
// Create Sum op, and submit net for execution.
|
||||||
std::vector<primitive> net;
|
std::vector<primitive> net;
|
||||||
stream* fwd_cpu_stream = CreateStream(ctx, cpu_engine);
|
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
mkldnn::sum sum_op(sum_pd);
|
mkldnn::sum sum_op(sum_pd);
|
||||||
std::unordered_map<int, memory> net_args = {
|
std::unordered_map<int, memory> net_args = {
|
||||||
|
@ -281,11 +281,19 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
|||||||
std::shared_ptr<stream> fwd_stream) {
|
std::shared_ptr<stream> fwd_stream) {
|
||||||
DCHECK_EQ(in_data.size(), context_.data_mem.size());
|
DCHECK_EQ(in_data.size(), context_.data_mem.size());
|
||||||
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
||||||
|
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||||
|
context_.data_mem_shdptr[i]->set_data_handle(
|
||||||
|
static_cast<void*>(in_data[i].get_data_handle()), *fwd_stream);
|
||||||
|
}
|
||||||
|
context_.dst_mem->set_data_handle(
|
||||||
|
static_cast<void*>(dst_data.get_data_handle()), *fwd_stream);
|
||||||
|
#else
|
||||||
context_.data_mem_shdptr[i]->set_data_handle(
|
context_.data_mem_shdptr[i]->set_data_handle(
|
||||||
static_cast<void*>(in_data[i].get_data_handle()));
|
static_cast<void*>(in_data[i].get_data_handle()));
|
||||||
}
|
}
|
||||||
context_.dst_mem->set_data_handle(
|
context_.dst_mem->set_data_handle(
|
||||||
static_cast<void*>(dst_data.get_data_handle()));
|
static_cast<void*>(dst_data.get_data_handle()));
|
||||||
|
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||||
|
|
||||||
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
||||||
context_.data_mem[i] = *context_.data_mem_shdptr[i];
|
context_.data_mem[i] = *context_.data_mem_shdptr[i];
|
||||||
@ -788,11 +796,13 @@ class MklConcatOp : public OpKernel {
|
|||||||
dnn_shape_dst);
|
dnn_shape_dst);
|
||||||
DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";
|
DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";
|
||||||
|
|
||||||
|
std::shared_ptr<stream> fwd_cpu_stream;
|
||||||
|
fwd_cpu_stream.reset(CreateStream(context, cpu_engine));
|
||||||
|
|
||||||
if (dnn_shape_dst.IsMklTensor())
|
if (dnn_shape_dst.IsMklTensor())
|
||||||
dst_md = dnn_shape_dst.GetMklLayout();
|
dst_md = dnn_shape_dst.GetMklLayout();
|
||||||
dst.SetUsrMem(dst_md, dst_tensor);
|
dst.SetUsrMem(dst_md, dst_tensor);
|
||||||
std::shared_ptr<stream> fwd_cpu_stream;
|
dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
|
||||||
fwd_cpu_stream.reset(CreateStream(context, cpu_engine));
|
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
auto concat_op = concat(concat_pd);
|
auto concat_op = concat(concat_pd);
|
||||||
std::unordered_map<int, memory> net_args = {
|
std::unordered_map<int, memory> net_args = {
|
||||||
@ -830,9 +840,10 @@ class MklConcatOp : public OpKernel {
|
|||||||
|
|
||||||
dst_md = dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout()
|
dst_md = dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout()
|
||||||
: dst_md;
|
: dst_md;
|
||||||
dst.SetUsrMem(dst_md, dst_tensor);
|
|
||||||
std::shared_ptr<stream> fwd_cpu_stream;
|
std::shared_ptr<stream> fwd_cpu_stream;
|
||||||
fwd_cpu_stream.reset(CreateStream(context, concat_fwd->GetEngine()));
|
fwd_cpu_stream.reset(CreateStream(context, concat_fwd->GetEngine()));
|
||||||
|
dst.SetUsrMem(dst_md, dst_tensor);
|
||||||
|
dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
|
||||||
// Execute concat
|
// Execute concat
|
||||||
concat_fwd->Execute(srcs_mem, dst.GetOpMem(), concat_fwd_dims,
|
concat_fwd->Execute(srcs_mem, dst.GetOpMem(), concat_fwd_dims,
|
||||||
fwd_cpu_stream);
|
fwd_cpu_stream);
|
||||||
|
@ -75,6 +75,9 @@ class MklDequantizeOp : public OpKernel {
|
|||||||
MklDnnData<T> src(&cpu_engine);
|
MklDnnData<T> src(&cpu_engine);
|
||||||
MklDnnData<float> dst(&cpu_engine);
|
MklDnnData<float> dst(&cpu_engine);
|
||||||
|
|
||||||
|
std::shared_ptr<stream> reorder_stream;
|
||||||
|
reorder_stream.reset(CreateStream(ctx, cpu_engine));
|
||||||
|
|
||||||
// If input is in MKL layout, then simply grab input layout; otherwise,
|
// If input is in MKL layout, then simply grab input layout; otherwise,
|
||||||
// construct input TF layout. For TF layout, although input shape
|
// construct input TF layout. For TF layout, although input shape
|
||||||
// (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
|
// (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
|
||||||
@ -85,6 +88,7 @@ class MklDequantizeOp : public OpKernel {
|
|||||||
: memory::desc(src_dims, MklDnnType<T>(), MEMORY_FORMAT::nhwc);
|
: memory::desc(src_dims, MklDnnType<T>(), MEMORY_FORMAT::nhwc);
|
||||||
|
|
||||||
src.SetUsrMem(src_md, &src_tensor);
|
src.SetUsrMem(src_md, &src_tensor);
|
||||||
|
src.SetUsrMemDataHandle(&src_tensor, reorder_stream);
|
||||||
|
|
||||||
Tensor* output_tensor = nullptr;
|
Tensor* output_tensor = nullptr;
|
||||||
MklDnnShape output_mkl_shape;
|
MklDnnShape output_mkl_shape;
|
||||||
@ -129,6 +133,7 @@ class MklDequantizeOp : public OpKernel {
|
|||||||
AllocateOutputSetMklShape(ctx, 0, &output_tensor, output_tf_shape,
|
AllocateOutputSetMklShape(ctx, 0, &output_tensor, output_tf_shape,
|
||||||
output_mkl_shape);
|
output_mkl_shape);
|
||||||
dst.SetUsrMem(dst_md, output_tensor);
|
dst.SetUsrMem(dst_md, output_tensor);
|
||||||
|
dst.SetUsrMemDataHandle(output_tensor, reorder_stream);
|
||||||
|
|
||||||
// The quantization logic here for mode SCALED is similar to the logic
|
// The quantization logic here for mode SCALED is similar to the logic
|
||||||
// in QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3.
|
// in QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3.
|
||||||
@ -155,8 +160,6 @@ class MklDequantizeOp : public OpKernel {
|
|||||||
// Also it does not define round_nearest (enum).
|
// Also it does not define round_nearest (enum).
|
||||||
attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
|
attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
|
||||||
#endif // !ENABLE_MKLDNN_V1
|
#endif // !ENABLE_MKLDNN_V1
|
||||||
std::shared_ptr<stream> reorder_stream;
|
|
||||||
reorder_stream.reset(CreateStream(ctx, cpu_engine));
|
|
||||||
std::vector<primitive> net;
|
std::vector<primitive> net;
|
||||||
|
|
||||||
// Create reorder primitive and then execute.
|
// Create reorder primitive and then execute.
|
||||||
|
@ -137,6 +137,7 @@ class MklLRNOp : public OpKernel {
|
|||||||
// that input is in NHWC layout with Channel being the last dimension.
|
// that input is in NHWC layout with Channel being the last dimension.
|
||||||
src_dnn_data.SetUsrMem(src_md, &src_tensor);
|
src_dnn_data.SetUsrMem(src_md, &src_tensor);
|
||||||
src_dnn_data.SetOpMemDesc(input_dims, MEMORY_FORMAT::nhwc);
|
src_dnn_data.SetOpMemDesc(input_dims, MEMORY_FORMAT::nhwc);
|
||||||
|
src_dnn_data.SetUsrMemDataHandle(&src_tensor, fwd_stream_);
|
||||||
|
|
||||||
// dst_dnn_data has the same shape as input.
|
// dst_dnn_data has the same shape as input.
|
||||||
dst_dnn_data.SetUsrMem(src_md);
|
dst_dnn_data.SetUsrMem(src_md);
|
||||||
@ -157,7 +158,7 @@ class MklLRNOp : public OpKernel {
|
|||||||
&output_tensor);
|
&output_tensor);
|
||||||
OP_REQUIRES_OK(context, context->status());
|
OP_REQUIRES_OK(context, context->status());
|
||||||
DCHECK(output_tensor != nullptr);
|
DCHECK(output_tensor != nullptr);
|
||||||
dst_dnn_data.SetUsrMemDataHandle(output_tensor);
|
dst_dnn_data.SetUsrMemDataHandle(output_tensor, fwd_stream_);
|
||||||
|
|
||||||
// Handle workspace required for MKL-DNN.
|
// Handle workspace required for MKL-DNN.
|
||||||
AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
|
AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
|
||||||
@ -393,6 +394,7 @@ class MklLRNGradOp : public OpKernel {
|
|||||||
orig_input_dnn_shape.GetSizesAsMklDnnDims();
|
orig_input_dnn_shape.GetSizesAsMklDnnDims();
|
||||||
orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
|
orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
|
||||||
orig_input_dnn_data.SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc);
|
orig_input_dnn_data.SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc);
|
||||||
|
orig_input_dnn_data.SetUsrMemDataHandle(&orig_input_tensor, bwd_stream_);
|
||||||
|
|
||||||
// output_dnn_data has the same shape as original input
|
// output_dnn_data has the same shape as original input
|
||||||
output_dnn_data.SetUsrMem(orig_input_md);
|
output_dnn_data.SetUsrMem(orig_input_md);
|
||||||
@ -421,7 +423,7 @@ class MklLRNGradOp : public OpKernel {
|
|||||||
orig_input_format, &output_tensor);
|
orig_input_format, &output_tensor);
|
||||||
OP_REQUIRES_OK(context, context->status());
|
OP_REQUIRES_OK(context, context->status());
|
||||||
DCHECK(output_tensor != nullptr);
|
DCHECK(output_tensor != nullptr);
|
||||||
output_dnn_data.SetUsrMemDataHandle(output_tensor);
|
output_dnn_data.SetUsrMemDataHandle(output_tensor, bwd_stream_);
|
||||||
|
|
||||||
// Create LRN primitive and add it to the net
|
// Create LRN primitive and add it to the net
|
||||||
// At this point, workspace is enabled, so we don't need
|
// At this point, workspace is enabled, so we don't need
|
||||||
|
@@ -137,6 +137,7 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
   memory::dims out_strides =
       ReorderStrides(CalculateTFStrides(out_dims), perm);

+  std::shared_ptr<stream> transpose_stream;
   in.SetUsrMem(in_dims, in_strides, &in_tensor);
   // Output dimensions are same as input dimensions. We adjust the layout
   // using strides.

@@ -144,16 +145,16 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,

   std::vector<primitive> net;
 #ifdef ENABLE_MKLDNN_V1
-  std::shared_ptr<stream> transpose_stream;
   auto* prim = FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem());
   transpose_stream.reset(CreateStream(context, prim->GetEngine()));
+  in.SetUsrMemDataHandle(&in_tensor, transpose_stream);
+  out.SetUsrMemDataHandle(out_tensor, transpose_stream);
   net.push_back(*(prim->GetPrimitive()));
   std::vector<MemoryArgsMap> net_args;
   net_args.push_back({{MKLDNN_ARG_FROM, *in.GetUsrMem()},
                       {MKLDNN_ARG_TO, *out.GetUsrMem()}});
   execute_primitives(net, transpose_stream, net_args);
 #else
-  std::shared_ptr<stream> transpose_stream;
   transpose_stream.reset(new CPU_STREAM(cpu_engine));
   net.push_back(FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem()));
   transpose_stream->submit(net).wait();
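Across both files the change follows one pattern: under MKL-DNN v1.x the user-memory data handle is bound together with the stream that will execute the primitive, rather than being set on the memory object alone. A minimal sketch of the resulting call sequence, reusing names from the hunks above purely for illustration (exec_stream is a hypothetical local, not part of this diff):

    // Sketch only: bind tensor buffers and the execution stream, then run the net.
    std::shared_ptr<stream> exec_stream(CreateStream(context, prim->GetEngine()));
    in.SetUsrMemDataHandle(&in_tensor, exec_stream);
    out.SetUsrMemDataHandle(out_tensor, exec_stream);
    net.push_back(*(prim->GetPrimitive()));
    execute_primitives(net, exec_stream, net_args);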
tensorflow/core/kernels/mlir_generated_op_gpu_tanh_test.cc (new file, 85 lines added)
@@ -0,0 +1,85 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+class MlirGeneratedOpGpuTanhTest : public OpsTestBase {
+ protected:
+  void SetUp() override {
+    std::unique_ptr<tensorflow::Device> device_gpu(
+        tensorflow::DeviceFactory::NewDevice("GPU", {},
+                                             "/job:a/replica:0/task:0"));
+    SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
+  }
+  template <typename T, typename RT = T>
+  void RunTanhOp(std::initializer_list<T> input) {
+    TensorShape shape({2, 7});
+    TF_ASSERT_OK(NodeDefBuilder("tanh_op", "Tanh")
+                     .Input(FakeInput(DataTypeToEnum<T>::v()))
+                     .Attr("T", DataTypeToEnum<T>::v())
+                     .Finalize(node_def()));
+
+    TF_ASSERT_OK(InitOp());
+    AddInputFromArray<T>(shape, input);
+    TF_ASSERT_OK(RunOpKernel());
+
+    Tensor expected_tensor(allocator(), DataTypeToEnum<T>::value, shape);
+    std::vector<T> expected;
+    expected.reserve(input.size());
+    for (const T& inp : input) {
+      expected.push_back(static_cast<T>(std::tanh(static_cast<RT>(inp))));
+    }
+    test::FillValues<T>(&expected_tensor, expected);
+    test::ExpectClose(expected_tensor, *GetOutput(0));
+  }
+};
+
+TEST_F(MlirGeneratedOpGpuTanhTest, TanhFloat) {
+  RunTanhOp<float>({-18.0f, -9.0f, -1e-6f, -0.0f, 0.0f, 1e-6, 0.1f, 0.2f, 0.3f,
+                    0.5f, 0.7f, 0.9f, 9.0f, 18.0f});
+}
+
+TEST_F(MlirGeneratedOpGpuTanhTest, TanhDouble) {
+  RunTanhOp<double>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1, 0.2, 0.3, 0.5,
+                     0.7, 0.9, 9.0, 18.0});
+}
+
+TEST_F(MlirGeneratedOpGpuTanhTest, TanhHalf) {
+  RunTanhOp<Eigen::half, float>(
+      {static_cast<Eigen::half>(-18.0), static_cast<Eigen::half>(-9.0),
+       static_cast<Eigen::half>(-1e-6), static_cast<Eigen::half>(-0.0),
+       static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(1e-6),
+       static_cast<Eigen::half>(0.1), static_cast<Eigen::half>(0.2),
+       static_cast<Eigen::half>(0.3), static_cast<Eigen::half>(0.5),
+       static_cast<Eigen::half>(0.7), static_cast<Eigen::half>(0.9),
+       static_cast<Eigen::half>(9.0), static_cast<Eigen::half>(18.0)});
+}
+
+}  // namespace
+}  // end namespace tensorflow
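A note on the harness above: RunTanhOp computes its expected values on the host with std::tanh, and the second template parameter RT exists so low-precision types can run that reference computation at higher precision. TanhHalf instantiates RunTanhOp<Eigen::half, float>, promoting each half input to float before taking tanh and casting the result back, and test::ExpectClose then compares within tolerance rather than bit-exactly.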
@@ -296,8 +296,6 @@ TF_CALL_tstring(REGISTER_CPU_KERNEL);
 TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
 TF_CALL_quint16(REGISTER_CPU_KERNEL);
 TF_CALL_qint16(REGISTER_CPU_KERNEL);
-TF_CALL_uint32(REGISTER_CPU_KERNEL);
-TF_CALL_uint64(REGISTER_CPU_KERNEL);
 #undef REGISTER_CPU_KERNEL
 #undef REGISTER_CPU_KERNEL_WITH_INDEX_TYPE

@@ -308,8 +308,6 @@ TF_CALL_tstring(REGISTER_KERNELS);
 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
 TF_CALL_quint16(REGISTER_KERNELS);
 TF_CALL_qint16(REGISTER_KERNELS);
-TF_CALL_uint32(REGISTER_KERNELS);
-TF_CALL_uint64(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
 #undef REGISTER_KERNELS_WITH_SPLIT_TYPE
 }  // namespace tensorflow

@@ -561,8 +561,6 @@ TF_CALL_string(REGISTER_CPU_KERNEL);
 TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
 TF_CALL_quint16(REGISTER_CPU_KERNEL);
 TF_CALL_qint16(REGISTER_CPU_KERNEL);
-TF_CALL_uint32(REGISTER_CPU_KERNEL);
-TF_CALL_uint64(REGISTER_CPU_KERNEL);

 #undef REGISTER_CPU_KERNEL

@@ -213,8 +213,6 @@ TF_CALL_tstring(REGISTER_KERNELS);
 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
 TF_CALL_quint16(REGISTER_KERNELS);
 TF_CALL_qint16(REGISTER_KERNELS);
-TF_CALL_uint32(REGISTER_KERNELS);
-TF_CALL_uint64(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
 #undef REGISTER_KERNELS_WITH_SPLIT_TYPE
 }  // namespace tensorflow
@@ -35,6 +35,7 @@ namespace tensorflow {

 typedef Eigen::GpuDevice GPUDevice;

+static constexpr int VectorSizeElements = 8;
 namespace functor {

 // This kernel computes ReluGrad by processing one half2, two fp16, at a time.

@@ -93,6 +94,66 @@ __global__ void ReluGradHalfKernel(const Eigen::half* __restrict__ gradient,
   }
 }

+__global__ void ReluGradHalfKernelVector(
+    const Eigen::half* __restrict__ gradient,
+    const Eigen::half* __restrict__ feature, Eigen::half* __restrict__ backprop,
+    int32 count) {
+  int32 half8_count = count / VectorSizeElements;
+  int32 index = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (index < half8_count) {
+    // Cast to xx_h8 for vector load and store.
+    float4 gradient_h8 = reinterpret_cast<const float4*>(gradient)[index];
+    float4 feature_h8 = reinterpret_cast<const float4*>(feature)[index];
+    float4* p_backprop_h8 = reinterpret_cast<float4*>(backprop) + index;
+
+    half2* gradient_h2 = reinterpret_cast<half2*>(&gradient_h8);
+    half2* feature_h2 = reinterpret_cast<half2*>(&feature_h8);
+    float4 backprop_h8;
+    half2* p_backprop_h2 = reinterpret_cast<half2*>(&backprop_h8);
+
+    // Fast path, when half2 primitives are available.
+#if __CUDA_ARCH__ >= 530
+    const half2 kZeroH2 = __float2half2_rn(0.f);
+#endif
+    for (int i = 0; i < VectorSizeElements / 2; i++) {
+#if __CUDA_ARCH__ >= 530
+      // mask = (feature > 0)
+      half2 mask_h2 = __hgt2(feature_h2[i], kZeroH2);
+      // backprop = mask * gradient
+      half2 backprop_h2 = __hmul2(mask_h2, gradient_h2[i]);
+#else
+      // Fall back: convert half2 to float2 for processing.
+      float2 feature_f2 = __half22float2(feature_h2[i]);
+      float2 gradient_f2 = __half22float2(gradient_h2[i]);
+      float2 backprop_f2 =
+          make_float2((feature_f2.x > 0.0f) ? float(gradient_f2.x) : 0.0f,
+                      (feature_f2.y > 0.0f) ? float(gradient_f2.y) : 0.0f);
+      // Convert back to half2.
+      half2 backprop_h2 = __float22half2_rn(backprop_f2);
+#endif
+      p_backprop_h2[i] = backprop_h2;
+    }
+    // Write back the result.
+    *p_backprop_h8 = backprop_h8;
+  }
+
+  int remaining_count = (count % VectorSizeElements);
+
+  if (index < remaining_count) {
+    // Use first threads to process the remaining elements.
+    Eigen::half grad_h = gradient[half8_count * VectorSizeElements + index];
+    Eigen::half feature_h = feature[half8_count * VectorSizeElements + index];
+
+    float grad_f = static_cast<float>(grad_h);
+    float feature_f = static_cast<float>(feature_h);
+    float backprop_f = (feature_f > 0) ? grad_f : 0;
+
+    Eigen::half backprop_h(backprop_f);
+    backprop[half8_count * VectorSizeElements + index] = backprop_h;
+  }
+}
+
 template <typename Device>
 struct ReluGrad<Device, Eigen::half> {
   // Computes ReluGrad backprop.

@@ -108,16 +169,29 @@ struct ReluGrad<Device, Eigen::half> {
     // NOTE: When the activation is exactly zero, we do not propagate the
     // associated gradient value. This allows the output of the Relu to be used,
     // as well as its input.
+    auto gradient_ptr = reinterpret_cast<uintptr_t>(gradient.data());
+    auto feature_ptr = reinterpret_cast<uintptr_t>(feature.data());
+    auto backprop_ptr = reinterpret_cast<uintptr_t>(backprop.data());
+    bool aligned = gradient_ptr % 16 == 0 && feature_ptr % 16 == 0 &&
+                   backprop_ptr % 16 == 0;
     int32 count = gradient.size();
-    if (count == 0) return;
-    int32 half2_count = Eigen::divup(count, 2);
     constexpr int32 kThreadInBlock = 512;
+    if (count == 0) return;
+    if (aligned) {
+      int32 half8_count = Eigen::divup(count, VectorSizeElements);
+      int32 kBlock = Eigen::divup(half8_count, kThreadInBlock);
+      TF_CHECK_OK(GpuLaunchKernel(
+          ReluGradHalfKernelVector, kBlock, kThreadInBlock, 0, d.stream(),
+          gradient.data(), feature.data(), backprop.data(), count));
+    } else {
+      int32 half2_count = Eigen::divup(count, 2);
       GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
           half2_count, d, ReluGradHalfKernel, 0, kThreadInBlock);
       TF_CHECK_OK(GpuLaunchKernel(
           ReluGradHalfKernel, config.block_count, config.thread_per_block, 0,
           d.stream(), gradient.data(), feature.data(), backprop.data(), count));
     }
+  }
 };

 __global__ void Relu_int8x4_kernel(int vect_count,
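The aligned/non-aligned split above exists because the vector path loads eight halves (16 bytes) per thread through a float4, which is only safe when every buffer is 16-byte aligned. A standalone sketch of the same check, with a hypothetical helper name that is not part of this diff:

    #include <cstdint>

    // Returns true only if all three buffers can be accessed as float4 (16 bytes).
    inline bool AllAligned16(const void* gradient, const void* feature,
                             const void* backprop) {
      auto g = reinterpret_cast<std::uintptr_t>(gradient);
      auto f = reinterpret_cast<std::uintptr_t>(feature);
      auto b = reinterpret_cast<std::uintptr_t>(backprop);
      return g % 16 == 0 && f % 16 == 0 && b % 16 == 0;
    }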
@@ -512,7 +512,6 @@ class AssignVariableOp<Device, Variant> : public OpKernel {

 TF_CALL_ALL_TYPES(REGISTER_KERNELS);
 TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
-TF_CALL_uint32(REGISTER_KERNELS);
 #undef REGISTER_KERNELS

 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -43,7 +43,6 @@ void Split<Eigen::ThreadPoolDevice, T, NDims>::operator()(

 TF_CALL_ALL_TYPES(DEFINE_CPU_KERNELS)
 DEFINE_CPU_KERNELS(quint8)
-DEFINE_CPU_KERNELS(uint64)

 #ifdef TENSORFLOW_USE_SYCL
 template <typename T, int NDims>

@@ -404,7 +404,6 @@ class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {

 TF_CALL_ALL_TYPES(REGISTER_SPLIT);
 REGISTER_SPLIT(quint8);
-REGISTER_SPLIT(uint64);

 #undef REGISTER_SPLIT

@@ -440,8 +440,6 @@ class StridedSliceAssignOp : public OpKernel {
                           StridedSliceAssignOp<CPUDevice, type, true>)

 TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE);
-TF_CALL_uint32(REGISTER_STRIDED_SLICE);
-TF_CALL_uint64(REGISTER_STRIDED_SLICE);

 #undef REGISTER_STRIDED_SLICE

@@ -287,8 +287,6 @@ TF_CALL_GPU_ALL_TYPES(DECLARE_FOR_N_GPU);
 #endif  // END GOOGLE_CUDA || TENSORFLOW_USE_ROCM

 TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);
-TF_CALL_uint32(DECLARE_FOR_N_CPU);
-TF_CALL_uint64(DECLARE_FOR_N_CPU);

 #ifdef TENSORFLOW_USE_SYCL
 #define PREVENT_FOR_N_SYCL(T) \
@@ -52,7 +52,8 @@ class SummaryScalarOp : public OpKernel {
     Summary s;
     for (int i = 0; i < Ttags.size(); i++) {
       Summary::Value* v = s.add_value();
-      v->set_tag(string(Ttags(i)));  // NOLINT
+      const tstring& Ttags_i = Ttags(i);
+      v->set_tag(Ttags_i.data(), Ttags_i.size());
       v->set_simple_value(float(Tvalues(i)));
     }

@@ -102,7 +103,8 @@ class SummaryHistoOp : public OpKernel {

     Summary s;
     Summary::Value* v = s.add_value();
-    v->set_tag(string(tags.scalar<tstring>()()));  // NOLINT
+    const tstring& tags0 = tags.scalar<tstring>()();
+    v->set_tag(tags0.data(), tags0.size());
     histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);

     Tensor* summary_tensor = nullptr;
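Both summary kernels now hand the tag to the proto setter as raw bytes rather than materializing a temporary std::string from the tstring, relying on the (const char*, size_t) overload that protobuf generates for string fields. A before/after sketch, reusing the names from the hunk above:

    // Before: copies the tstring into a std::string only to set the proto field.
    v->set_tag(string(Ttags(i)));  // NOLINT
    // After: passes the underlying buffer and length directly, avoiding the copy.
    const tstring& Ttags_i = Ttags(i);
    v->set_tag(Ttags_i.data(), Ttags_i.size());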
@@ -258,7 +258,6 @@ namespace functor {

 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
 TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);
-TF_CALL_uint32(DECLARE_GPU_SPEC);

 #undef DECLARE_GPU_SPEC

@@ -276,7 +275,6 @@ TF_CALL_uint32(DECLARE_GPU_SPEC);

 TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
 TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
-TF_CALL_uint32(REGISTER_KERNELS)
 #undef REGISTER_KERNELS

 #endif  // end GOOGLE_CUDA
tensorflow/core/kernels/topk_op_gpu_uint32.cu.cc (new file, 28 lines added)
@@ -0,0 +1,28 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/topk_op.h"
+#include "tensorflow/core/kernels/topk_op_gpu.h"
+
+namespace tensorflow {
+using Eigen::GpuDevice;
+
+template struct functor::TopKFunctor<GPUDevice, uint32>;
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
tensorflow/core/kernels/topk_op_gpu_uint64.cu.cc (new file, 28 lines added)
@@ -0,0 +1,28 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/topk_op.h"
+#include "tensorflow/core/kernels/topk_op_gpu.h"
+
+namespace tensorflow {
+using Eigen::GpuDevice;
+
+template struct functor::TopKFunctor<GPUDevice, uint64>;
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
@@ -252,7 +252,8 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);

 TF_CALL_int64(REGISTER_GPU_KERNELS);
 TF_CALL_uint32(REGISTER_GPU_KERNELS);
-TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+TF_CALL_COMPLEX_TYPES(REGISTER_GPU_KERNELS);
 #undef REGISTER_GPU_KERNELS
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -78,3 +78,32 @@ op {
     }
   }
 }
+op {
+  name: "Acos"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
@@ -78,3 +78,32 @@ op {
     }
   }
 }
+op {
+  name: "Asin"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
@@ -78,3 +78,32 @@ op {
     }
   }
 }
+op {
+  name: "Atan"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
@@ -168,3 +168,32 @@ op {
     }
   }
 }
+op {
+  name: "Inv"
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_BFLOAT16
+        type: DT_HALF
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT8
+        type: DT_INT16
+        type: DT_INT32
+        type: DT_INT64
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
Some files were not shown because too many files have changed in this diff.