Merge branch 'master' into master

Commit fccd345ab3

@@ -24,9 +24,21 @@ cc_library(
        "//tensorflow:windows": get_win_copts(),
    }),
    deps = [
        ":gcs_helper",
        "//tensorflow/c:env",
        "//tensorflow/c:tf_status",
        "//tensorflow/c/experimental/filesystem:filesystem_interface",
        "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "gcs_helper",
    srcs = ["gcs_helper.cc"],
    hdrs = ["gcs_helper.h"],
    linkstatic = 1,
    deps = [
        "//tensorflow/c:env",
    ],
)

@@ -15,9 +15,13 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>

#include <fstream>

#include "absl/strings/string_view.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include "tensorflow/c/tf_status.h"

// Implementation of a filesystem for GCS environments.

@@ -86,6 +90,20 @@ namespace tf_random_access_file {
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
typedef struct GCSFile {
  const char* bucket;
  const char* object;
  gcs::Client* gcs_client;  // not owned
  TempFile outfile;
  bool sync_need;
} GCSFile;

static void Cleanup(TF_WritableFile* file) {
  auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
  plugin_memory_free(const_cast<char*>(gcs_file->bucket));
  plugin_memory_free(const_cast<char*>(gcs_file->object));
  delete gcs_file;
}

// TODO(vnvo2409): Implement later

@@ -119,6 +137,20 @@ static void Init(TF_Filesystem* filesystem, TF_Status* status) {

// TODO(vnvo2409): Implement later

static void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
                            TF_WritableFile* file, TF_Status* status) {
  char* bucket;
  char* object;
  ParseGCSPath(path, false, &bucket, &object, status);
  if (TF_GetCode(status) != TF_OK) return;

  auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
  TempFile outfile(TF_GetTempFileName(""), std::ios::binary | std::ios::out);
  file->plugin_file = new tf_writable_file::GCSFile(
      {bucket, object, gcs_client, std::move(outfile), true});
  TF_SetStatus(status, TF_OK, "");
}

}  // namespace tf_gcs_filesystem

static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,

@@ -126,9 +158,14 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
  TF_SetFilesystemVersionMetadata(ops);
  ops->scheme = strdup(uri);

  ops->writable_file_ops = static_cast<TF_WritableFileOps*>(
      plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
  ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;

  ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
      plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
  ops->filesystem_ops->init = tf_gcs_filesystem::Init;
  ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
}

void TF_InitPlugin(TF_FilesystemPluginInfo* info) {

@@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"

#include <stdio.h>

#include <fstream>
#include <string>
#include <utility>

TempFile::TempFile(const char* temp_file_name, std::ios::openmode mode)
    : std::fstream(temp_file_name, mode), name_(temp_file_name) {}

TempFile::TempFile(TempFile&& rhs)
    : std::fstream(std::move(rhs)), name_(std::move(rhs.name_)) {}

TempFile::~TempFile() {
  std::fstream::close();
  std::remove(name_.c_str());
}

const std::string TempFile::getName() const { return name_; }

@@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_

#include <fstream>
#include <string>

class TempFile : public std::fstream {
 public:
  // We should specify openmode each time we call TempFile.
  TempFile(const char* temp_file_name, std::ios::openmode mode);
  TempFile(TempFile&& rhs);
  ~TempFile() override;
  const std::string getName() const;

 private:
  const std::string name_;
};

#endif  // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_

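For readers skimming the new helper, a minimal usage sketch may help (illustrative only, not part of the commit). It assumes the header added above is on the include path; `example_buffer.tmp` is just a placeholder name, whereas the plugin itself obtains the path from `TF_GetTempFileName`.

#include <iostream>

#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"

int main() {
  {
    // The openmode is passed explicitly, as the class comment requires.
    TempFile tmp("example_buffer.tmp", std::ios::binary | std::ios::out);
    tmp << "bytes staged locally before being uploaded to GCS";
    std::cout << "buffering into " << tmp.getName() << std::endl;
  }  // ~TempFile() closes the stream and removes example_buffer.tmp.
  return 0;
}
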
@ -314,7 +314,6 @@ tf_cc_test(
|
|||
cc_library(
|
||||
name = "tensorflow_lite_legalize_tf",
|
||||
srcs = [
|
||||
"transforms/device_index_selector.cc",
|
||||
"transforms/dilated_conv.cc",
|
||||
"transforms/generated_legalize_tf.inc",
|
||||
"transforms/generated_lower_static_tensor_list.inc",
|
||||
|
|
|
@ -953,14 +953,14 @@ in the batch dimensions and broadcasting.
|
|||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TFL_TensorOf<[F32]>:$x,
|
||||
TFL_TensorOf<[F32]>:$y,
|
||||
TFL_TensorOf<[F32, QI8]>:$x,
|
||||
TFL_TensorOf<[F32, QI8]>:$y,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$adj_x,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$adj_y
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TFL_TensorOf<[F32]>:$output
|
||||
TFL_TensorOf<[F32, QI8]>:$output
|
||||
);
|
||||
|
||||
let hasOptions = 1;
|
||||
|
|
|
@ -63,7 +63,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
|||
standard_pipeline_options.enable_inliner = false;
|
||||
standard_pipeline_options.form_clusters = pass_config.form_clusters;
|
||||
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
|
||||
pass_manager->addPass(mlir::TFL::CreateDeviceIndexSelectorPass());
|
||||
pass_manager->addPass(mlir::TF::CreateDeviceIndexSelectorPass());
|
||||
|
||||
if (pass_config.shape_inference) {
|
||||
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||
|
|
|
@ -91,9 +91,6 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
|
|||
// Verifies runtime constraints.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
|
||||
|
||||
// Creates function pass to select device index/fold tf.DeviceIndex.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
|
||||
|
||||
} // namespace TFL
|
||||
|
||||
} // namespace mlir
|
||||
|
|
|
@ -475,6 +475,7 @@ cc_library(
|
|||
"transforms/cluster_outlining.cc",
|
||||
"transforms/collection_ops_util.cc",
|
||||
"transforms/decompose_resource_ops_pass.cc",
|
||||
"transforms/device_index_selector.cc",
|
||||
"transforms/einsum.cc",
|
||||
"transforms/executor_island_coarsening.cc",
|
||||
"transforms/executor_tpuv1_inline_tpu_island.cc",
|
||||
|
|
|
@ -164,6 +164,81 @@ def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastable
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TF_AdjustContrastv2Op : TF_Op<"AdjustContrastv2", [NoSideEffect]> {
|
||||
let summary = "Adjust the contrast of one or more images.";
|
||||
|
||||
let description = [{
|
||||
`images` is a tensor of at least 3 dimensions. The last 3 dimensions are
|
||||
interpreted as `[height, width, channels]`. The other dimensions only
|
||||
represent a collection of images, such as `[batch, height, width, channels].`
|
||||
|
||||
Contrast is adjusted independently for each channel of each image.
|
||||
|
||||
For each channel, the Op first computes the mean of the image pixels in the
|
||||
channel and then adjusts each component of each pixel to
|
||||
`(x - mean) * contrast_factor + mean`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F16, F32]>:$images,
|
||||
F32Tensor:$contrast_factor
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F16, F32]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
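The contrast formula quoted in the description above, `(x - mean) * contrast_factor + mean`, is easy to check with a small standalone sketch (illustrative only, not part of this commit); `AdjustContrast` below is a hypothetical helper operating on one flattened channel.

#include <cstdio>
#include <numeric>
#include <vector>

// Hypothetical helper: applies (x - mean) * contrast_factor + mean to every
// pixel of a single flattened channel, mirroring the AdjustContrastv2 formula.
std::vector<float> AdjustContrast(const std::vector<float>& channel,
                                  float contrast_factor) {
  const float mean = std::accumulate(channel.begin(), channel.end(), 0.0f) /
                     static_cast<float>(channel.size());
  std::vector<float> out(channel.size());
  for (size_t i = 0; i < channel.size(); ++i) {
    out[i] = (channel[i] - mean) * contrast_factor + mean;
  }
  return out;
}

int main() {
  const std::vector<float> adjusted = AdjustContrast({0.2f, 0.4f, 0.6f}, 2.0f);
  for (float v : adjusted) std::printf("%f\n", v);  // prints 0.0, 0.4, 0.8
  return 0;
}
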
|
||||
|
||||
def TF_AdjustHueOp : TF_Op<"AdjustHue", [NoSideEffect]> {
|
||||
let summary = "Adjust the hue of one or more images.";
|
||||
|
||||
let description = [{
|
||||
`images` is a tensor of at least 3 dimensions. The last dimension is
|
||||
interpreted as channels, and must be three.
|
||||
|
||||
The input image is considered in the RGB colorspace. Conceptually, the RGB
|
||||
colors are first mapped into HSV. A delta is then applied to all the hue values,
|
||||
and then remapped back to RGB colorspace.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F16, F32]>:$images,
|
||||
F32Tensor:$delta
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F16, F32]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_AdjustSaturationOp : TF_Op<"AdjustSaturation", [NoSideEffect]> {
|
||||
let summary = "Adjust the saturation of one or more images.";
|
||||
|
||||
let description = [{
|
||||
`images` is a tensor of at least 3 dimensions. The last dimension is
|
||||
interpreted as channels, and must be three.
|
||||
|
||||
The input image is considered in the RGB colorspace. Conceptually, the RGB
|
||||
colors are first mapped into HSV. A scale is then applied to all the saturation
|
||||
values, and then remapped back to RGB colorspace.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[F16, F32]>:$images,
|
||||
F32Tensor:$scale
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[F16, F32]>:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_AllOp : TF_Op<"All", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Computes the "logical and" of elements across dimensions of a tensor.
|
||||
|
@ -3866,6 +3941,28 @@ tf.math.greater_equal(x, y) ==> [True, False, True, True]
|
|||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_HSVToRGBOp : TF_Op<"HSVToRGB", [NoSideEffect]> {
|
||||
let summary = "Convert one or more images from HSV to RGB.";
|
||||
|
||||
let description = [{
|
||||
Outputs a tensor of the same shape as the `images` tensor, containing the RGB
|
||||
value of the pixels. The output is only well defined if the values in `images`
|
||||
are in `[0,1]`.
|
||||
|
||||
See `rgb_to_hsv` for a description of the HSV encoding.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_FpTensor:$images
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_FpTensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_HashTableV2Op : TF_Op<"HashTableV2", []> {
|
||||
let summary = "Creates a non-initialized hash table.";
|
||||
|
||||
|
@ -5962,11 +6059,11 @@ I.e., \\(y = -x\\).
|
|||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
|
@ -6733,6 +6830,41 @@ the dimension is padded with zeros.
|
|||
TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_RGBToHSVOp : TF_Op<"RGBToHSV", [NoSideEffect]> {
|
||||
let summary = "Converts one or more images from RGB to HSV.";
|
||||
|
||||
let description = [{
|
||||
Outputs a tensor of the same shape as the `images` tensor, containing the HSV
|
||||
value of the pixels. The output is only well defined if the values in `images`
|
||||
are in `[0,1]`.
|
||||
|
||||
`output[..., 0]` contains hue, `output[..., 1]` contains saturation, and
|
||||
`output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0
|
||||
corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue.
|
||||
|
||||
Usage Example:
|
||||
|
||||
>>> blue_image = tf.stack([
|
||||
... tf.zeros([5,5]),
|
||||
... tf.zeros([5,5]),
|
||||
... tf.ones([5,5])],
|
||||
... axis=-1)
|
||||
>>> blue_hsv_image = tf.image.rgb_to_hsv(blue_image)
|
||||
>>> blue_hsv_image[0,0].numpy()
|
||||
array([0.6666667, 1. , 1. ], dtype=float32)
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_FpTensor:$images
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_FpTensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_RandomGammaGradOp : TF_Op<"RandomGammaGrad", [NoSideEffect, ResultsBroadcastableShape]>,
|
||||
WithBroadcastableBinOpBuilder {
|
||||
let summary = [{
|
||||
|
@ -7230,6 +7362,27 @@ Input images can be of different types but output images are always float.
|
|||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_ResizeBilinearGradOp : TF_Op<"ResizeBilinearGrad", [NoSideEffect]> {
|
||||
let summary = "Computes the gradient of bilinear interpolation.";
|
||||
|
||||
let description = [{
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
F32Tensor:$grads,
|
||||
TF_FpTensor:$original_image,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "false">:$align_corners,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$half_pixel_centers
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_FpTensor:$output
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>;
|
||||
}
|
||||
|
||||
def TF_ResizeNearestNeighborOp : TF_Op<"ResizeNearestNeighbor", [NoSideEffect]> {
|
||||
let summary = [{
|
||||
Resize `images` to `size` using nearest neighbor interpolation.
|
||||
|
|
|
@ -21,11 +21,11 @@ limitations under the License.
|
|||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
namespace TF {
|
||||
namespace {
|
||||
|
||||
// Folds the DeviceIndex op to a constant value. The DeviceIndex returns the
|
||||
|
@ -55,8 +55,8 @@ void DeviceIndexSelector::runOnOperation() {
|
|||
// Convert all the DeviceIndex ops to constant values.
|
||||
func.getBody().walk([](TF::DeviceIndexOp op) {
|
||||
// This just selects the default in all cases where DeviceIndex feeds into
|
||||
// tf.Case. This could be enhanced based on explicit TFLite specification or
|
||||
// TAC in future.
|
||||
// tf.Case. This could be enhanced to have some sort of policy in the
|
||||
// future.
|
||||
OpBuilder b(op);
|
||||
RankedTensorType type = RankedTensorType::get({}, b.getIntegerType(32));
|
||||
int index = op.device_names().size();
|
||||
|
@ -79,7 +79,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass() {
|
|||
}
|
||||
|
||||
static PassRegistration<DeviceIndexSelector> pass(
|
||||
"tfl-device-index-selector", "Fold tf.DeviceIndex to constant");
|
||||
"tf-device-index-selector", "Fold tf.DeviceIndex to constant");
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
|
@ -52,11 +52,6 @@ struct FusedKernelMatcherPass
|
|||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
// Returns an op's name with the dialect prefix stripped off.
|
||||
StringRef GetOpNameWithoutDialect(Operation *op) {
|
||||
return op->getName().getStringRef().split(".").second;
|
||||
}
|
||||
|
||||
bool IsActivationFunction(Operation *op) {
|
||||
return isa<EluOp>(op) || isa<ReluOp>(op) || isa<Relu6Op>(op);
|
||||
}
|
||||
|
@ -128,8 +123,8 @@ class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
|
|||
}
|
||||
|
||||
SmallVector<Location, 3> locations{contraction.getLoc(), bias_add.getLoc()};
|
||||
SmallVector<Attribute, 2> fused_ops{
|
||||
StringAttr::get(GetOpNameWithoutDialect(bias_add), context)};
|
||||
SmallVector<Attribute, 2> fused_ops{StringAttr::get(
|
||||
bias_add.getOperation()->getName().stripDialect(), context)};
|
||||
|
||||
// BiasAdd may or may not feed into an activation function.
|
||||
auto activation = GetActivation(bias_add);
|
||||
|
@ -143,7 +138,7 @@ class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
|
|||
if (fuse_activation) {
|
||||
locations.push_back(activation->getLoc());
|
||||
fused_ops.push_back(
|
||||
StringAttr::get(GetOpNameWithoutDialect(activation), context));
|
||||
StringAttr::get(activation->getName().stripDialect(), context));
|
||||
result_type = activation->getResultTypes().front();
|
||||
} else {
|
||||
result_type = bias_add.getResult().getType();
|
||||
|
|
|
@ -147,6 +147,9 @@ std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
|
|||
// generally used beyond exporting to runtimes that supports these ops. In the
|
||||
// future these fusions may be codegen'd automatically.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass();
|
||||
|
||||
// Creates function pass to select device index/fold tf.DeviceIndex.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
|
||||
} // namespace TF
|
||||
|
||||
namespace tf_executor {
|
||||
|
|
|
@ -106,53 +106,6 @@ DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices,
|
|||
return GetI64ElementsAttr(slice_limits, builder);
|
||||
}
|
||||
|
||||
// Returns the padding value of the given position. If padding_attr is a
|
||||
// nullptr, returns 0.
|
||||
static int64_t GetPaddingValue(DenseIntElementsAttr padding_attr,
|
||||
ArrayRef<uint64_t> index) {
|
||||
if (!padding_attr) return 0;
|
||||
return padding_attr.getValue<int64_t>(index);
|
||||
}
|
||||
|
||||
static bool IsOnlyPaddingSpatialDims(Value lhs,
|
||||
ConvDimensionNumbers dimension_numbers,
|
||||
DenseIntElementsAttr edge_padding_low,
|
||||
DenseIntElementsAttr edge_padding_high) {
|
||||
const int64_t batch_dim = dimension_numbers.input_batch_dimension().getInt();
|
||||
const int64_t feature_dim =
|
||||
dimension_numbers.input_feature_dimension().getInt();
|
||||
if (edge_padding_low.getValue<int64_t>(batch_dim) ||
|
||||
edge_padding_high.getValue<int64_t>(batch_dim))
|
||||
return false;
|
||||
if (edge_padding_low.getValue<int64_t>(feature_dim) ||
|
||||
edge_padding_high.getValue<int64_t>(feature_dim))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
DenseIntElementsAttr BuildConvPaddingAttrs(
|
||||
DenseIntElementsAttr edge_padding_low,
|
||||
DenseIntElementsAttr edge_padding_high, DenseIntElementsAttr padding_attr,
|
||||
ConvDimensionNumbers dimension_numbers, Builder* builder) {
|
||||
SmallVector<int64_t, 4> padding_low, padding_high;
|
||||
for (const auto& dim : dimension_numbers.input_spatial_dimensions()) {
|
||||
unsigned i = dim.getZExtValue();
|
||||
padding_low.push_back(edge_padding_low.getValue<int64_t>(i));
|
||||
padding_high.push_back(edge_padding_high.getValue<int64_t>(i));
|
||||
}
|
||||
|
||||
int rank = padding_low.size();
|
||||
SmallVector<int64_t, 8> padding;
|
||||
for (unsigned i = 0, e = rank; i < e; ++i) {
|
||||
padding.push_back(GetPaddingValue(padding_attr, {i, 0}) + padding_low[i]);
|
||||
padding.push_back(GetPaddingValue(padding_attr, {i, 1}) + padding_high[i]);
|
||||
}
|
||||
// padding_attr.getType() doesn't work because it is an optional attribute,
|
||||
// which can be a nullptr.
|
||||
auto type = RankedTensorType::get({rank, 2}, builder->getIntegerType(64));
|
||||
return DenseIntElementsAttr::get(type, padding);
|
||||
}
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/generated_canonicalize.inc"
|
||||
} // namespace
|
||||
|
||||
|
@ -2153,14 +2106,5 @@ LogicalResult deriveShapeFromFirstOperand(
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConvOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ConvOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
|
||||
MLIRContext* context) {
|
||||
results.insert<FoldPadIntoConv>(context);
|
||||
}
|
||||
|
||||
} // namespace xla_hlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -929,8 +929,6 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
|
|||
);
|
||||
|
||||
let results = (outs HLO_Tensor);
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CopyOp {
|
||||
|
|
|
@ -132,6 +132,22 @@ StatusOr<XlaOp> MlirHloBuilder::FftInternal(
|
|||
return MakeXlaOp(op);
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal(
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const Shape& shape, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
|
||||
if (operand_shapes_with_layout.has_value())
|
||||
return Unimplemented(
|
||||
"CustomCall doesn't support operands shapes with layout");
|
||||
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
|
||||
shape, builder_));
|
||||
auto op = builder_.create<mlir::xla_hlo::CustomCallOp>(
|
||||
loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name),
|
||||
/*has_side_effect=*/builder_.getBoolAttr(false),
|
||||
builder_.getStringAttr(opaque));
|
||||
return MakeXlaOp(op);
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
|
||||
const Shape& shape, absl::Span<const XlaOp> all_operands,
|
||||
const XlaComputation& computation,
|
||||
|
|
|
@ -124,6 +124,12 @@ class MlirHloBuilder : public XlaBuilder {
|
|||
FftType fft_type,
|
||||
absl::Span<const int64> fft_length) override;
|
||||
|
||||
StatusOr<XlaOp> CustomCallInternal(const string& call_target_name,
|
||||
absl::Span<const XlaOp> operands,
|
||||
const Shape& shape, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>>
|
||||
operand_shapes_with_layout) override;
|
||||
|
||||
StatusOr<XlaOp> ReduceInternal(
|
||||
const Shape& shape, absl::Span<const XlaOp> all_operands,
|
||||
const XlaComputation& computation,
|
||||
|
|
|
@ -415,71 +415,6 @@ func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> {
|
|||
return %0 : tensor<1x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_pad_into_conv_f32
|
||||
func @fold_pad_into_conv_f32(%arg0 : tensor<1x32x32x3xf32>,
|
||||
%arg1 : tensor<7x7x3x64xf32>)
|
||||
-> tensor<1x16x16x64xf32> {
|
||||
// CHECK-NOT: xla_hlo.pad
|
||||
// CHECK: xla_hlo.convolution
|
||||
// CHECK-SAME: padding = dense<3> : tensor<2x2xi64>
|
||||
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%1 = "xla_hlo.pad"(%arg0, %0) {
|
||||
edge_padding_high = dense<[0, 3, 3, 0]> : tensor<4xi64>,
|
||||
edge_padding_low = dense<[0, 3, 3, 0]> : tensor<4xi64>,
|
||||
interior_padding = dense<0> : tensor<4xi64>
|
||||
} : (tensor<1x32x32x3xf32>, tensor<f32>) -> tensor<1x38x38x3xf32>
|
||||
%2 = "xla_hlo.convolution"(%1, %arg1) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
input_batch_dimension = 0 : i64,
|
||||
input_feature_dimension = 3 : i64,
|
||||
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
|
||||
kernel_input_feature_dimension = 2 : i64,
|
||||
kernel_output_feature_dimension = 3 : i64,
|
||||
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
||||
output_batch_dimension = 0 : i64,
|
||||
output_feature_dimension = 3 : i64,
|
||||
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
},
|
||||
feature_group_count = 1 : i64,
|
||||
padding = dense<0> : tensor<2x2xi64>,
|
||||
window_strides = dense<2> : tensor<2xi64>
|
||||
} : (tensor<1x38x38x3xf32>, tensor<7x7x3x64xf32>) -> tensor<1x16x16x64xf32>
|
||||
return %2 : tensor<1x16x16x64xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_pad_into_conv_i32
|
||||
func @fold_pad_into_conv_i32(%arg0 : tensor<1x32x32x3xi32>,
|
||||
%arg1 : tensor<7x7x3x64xi32>)
|
||||
-> tensor<1x16x16x64xi32> {
|
||||
// CHECK-NOT: xla_hlo.pad
|
||||
// CHECK: xla_hlo.convolution
|
||||
// CHECK-SAME: padding = dense<3> : tensor<2x2xi64>
|
||||
%0 = xla_hlo.constant dense<0> : tensor<i32>
|
||||
%1 = "xla_hlo.pad"(%arg0, %0) {
|
||||
edge_padding_high = dense<[0, 3, 3, 0]> : tensor<4xi64>,
|
||||
edge_padding_low = dense<[0, 3, 3, 0]> : tensor<4xi64>,
|
||||
interior_padding = dense<0> : tensor<4xi64>
|
||||
} : (tensor<1x32x32x3xi32>, tensor<i32>) -> tensor<1x38x38x3xi32>
|
||||
%2 = "xla_hlo.convolution"(%1, %arg1) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
input_batch_dimension = 0 : i64,
|
||||
input_feature_dimension = 3 : i64,
|
||||
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
|
||||
kernel_input_feature_dimension = 2 : i64,
|
||||
kernel_output_feature_dimension = 3 : i64,
|
||||
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
||||
output_batch_dimension = 0 : i64,
|
||||
output_feature_dimension = 3 : i64,
|
||||
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
},
|
||||
feature_group_count = 1 : i64,
|
||||
window_strides = dense<2> : tensor<2xi64>
|
||||
} : (tensor<1x38x38x3xi32>, tensor<7x7x3x64xi32>) -> tensor<1x16x16x64xi32>
|
||||
return %2 : tensor<1x16x16x64xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic
|
||||
func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> {
|
||||
// CHECK: xla_hlo.reshape
|
||||
|
|
|
@ -35,7 +35,7 @@ func @not_whitelisted_op(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor
|
|||
// CHECK-LABEL: unranked_operand
|
||||
func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: tf.Abs
|
||||
// expected-remark@+1 {{lowering requires static shaped operands}}
|
||||
// expected-remark@+1 {{lowering requires static shaped tensor operands}}
|
||||
%0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
|
||||
return %0 : tensor<*xf32>
|
||||
|
@ -44,12 +44,20 @@ func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
|||
// CHECK-LABEL: dynamic_operand
|
||||
func @dynamic_operand(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: tf.Abs
|
||||
// expected-remark@+1 {{lowering requires static shaped operands}}
|
||||
// expected-remark@+1 {{lowering requires static shaped tensor operands}}
|
||||
%0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: tuple_type
|
||||
func @tuple_type(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
|
||||
// Verifies that the pass can handle operands of non-tensor type like tuple
|
||||
// from non TensorFlow ops.
|
||||
%0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: unsupported_dtype
|
||||
func @unsupported_dtype(%arg0: tensor<2x!tf.variant>) -> tensor<2x!tf.variant> {
|
||||
// CHECK: tf.AddN
|
||||
|
|
|
@ -28,54 +28,3 @@ def UnaryEinsumToEinsum : Pat<
|
|||
(HLO_UnaryEinsumOp $operand, $equation),
|
||||
(HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)),
|
||||
$operand, (UnaryToBinaryEinsumEq $equation))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conv op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def IsZero : Attr<CPred<
|
||||
"($_self.isa<DenseFPElementsAttr>() &&"
|
||||
"$_self.cast<DenseFPElementsAttr>().isSplat() &&"
|
||||
"$_self.cast<DenseFPElementsAttr>().getSplatValue<FloatAttr>()"
|
||||
".getValue().isZero()) ||"
|
||||
"($_self.isa<DenseIntElementsAttr>() &&"
|
||||
"$_self.cast<DenseIntElementsAttr>().isSplat() &&"
|
||||
"$_self.cast<DenseIntElementsAttr>().getSplatValue<IntegerAttr>()"
|
||||
".getInt() == 0)">>;
|
||||
|
||||
def IsOnlyPaddingSpatialDims
|
||||
: Constraint<CPred<"IsOnlyPaddingSpatialDims($0, $1, $2, $3)">>;
|
||||
|
||||
def BuildConvPaddingAttrs : NativeCodeCall<
|
||||
"BuildConvPaddingAttrs($0, $1, $2, $3, &$_builder)">;
|
||||
|
||||
def FoldPadIntoConv : Pat<
|
||||
(HLO_ConvOp
|
||||
(HLO_PadOp $lhs,
|
||||
(HLO_ConstOp IsZero:$padding_value),
|
||||
$edge_padding_low,
|
||||
$edge_padding_high,
|
||||
IsZero:$interior_padding),
|
||||
$rhs,
|
||||
$window_strides,
|
||||
$padding,
|
||||
$lhs_dilation,
|
||||
$rhs_dilation,
|
||||
$dimension_numbers,
|
||||
$feature_group_count,
|
||||
$batch_group_count,
|
||||
$precision_config),
|
||||
(HLO_ConvOp
|
||||
$lhs,
|
||||
$rhs,
|
||||
$window_strides,
|
||||
(BuildConvPaddingAttrs $edge_padding_low, $edge_padding_high, $padding,
|
||||
$dimension_numbers),
|
||||
$lhs_dilation,
|
||||
$rhs_dilation,
|
||||
$dimension_numbers,
|
||||
$feature_group_count,
|
||||
$batch_group_count,
|
||||
$precision_config),
|
||||
[(IsOnlyPaddingSpatialDims $lhs, $dimension_numbers, $edge_padding_low,
|
||||
$edge_padding_high)]>;
|
||||
|
|
|
@ -389,10 +389,13 @@ struct HloLegalizeToLhlo
|
|||
target.addLegalOp<ModuleTerminatorOp>();
|
||||
target.addLegalOp<TensorFromElementsOp>();
|
||||
target.addIllegalDialect<xla_hlo::XlaHloDialect>();
|
||||
|
||||
BufferAssignmentTypeConverter converter;
|
||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||
auto inputs = op.getType().getInputs();
|
||||
return std::all_of(inputs.begin(), inputs.end(),
|
||||
[](Type input) { return input.isa<MemRefType>(); });
|
||||
return llvm::all_of(inputs,
|
||||
[](Type input) { return input.isa<MemRefType>(); }) &&
|
||||
converter.isLegal(&op.getBody());
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
|
||||
return std::all_of(returnOp.operand_type_begin(),
|
||||
|
@ -401,8 +404,7 @@ struct HloLegalizeToLhlo
|
|||
});
|
||||
|
||||
auto module = getOperation();
|
||||
BufferAssignmentTypeConverter converter;
|
||||
module.walk([&](FuncOp func) {
|
||||
module.walk([&](FuncOp func) -> WalkResult {
|
||||
BufferAssignmentPlacer bufferAssignment(func);
|
||||
OwningRewritePatternList patterns;
|
||||
populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment,
|
||||
|
@ -418,8 +420,7 @@ struct HloLegalizeToLhlo
|
|||
/*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment,
|
||||
&converter, &patterns);
|
||||
}
|
||||
return WalkResult(
|
||||
applyPartialConversion(func, target, patterns, &converter));
|
||||
return applyPartialConversion(func, target, patterns);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -463,6 +464,7 @@ void populateHLOToLHLOConversionPattern(
|
|||
HloToLhloOpConverter<xla_hlo::RealOp>,
|
||||
HloToLhloOpConverter<xla_hlo::RemOp>,
|
||||
HloToLhloOpConverter<xla_hlo::RsqrtOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ReshapeOp>,
|
||||
HloToLhloOpConverter<xla_hlo::SelectOp>,
|
||||
HloToLhloOpConverter<xla_hlo::SignOp>,
|
||||
HloToLhloOpConverter<xla_hlo::SqrtOp>,
|
||||
|
|
|
@ -5238,8 +5238,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
|
|||
// Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp.
|
||||
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ::mlir::ReturnOp>();
|
||||
DenseSet<Operation *> nonlegalized_ops;
|
||||
LogicalResult result = applyPartialConversion(
|
||||
op, target, patterns, /*converter=*/nullptr, &nonlegalized_ops);
|
||||
LogicalResult result =
|
||||
applyPartialConversion(op, target, patterns, &nonlegalized_ops);
|
||||
// In order to enforce that the conversion result is fully converted,
|
||||
// fail if there are any nonlegalized ops in the set.
|
||||
if (failed(result) || !nonlegalized_ops.empty()) {
|
||||
|
|
|
@ -88,6 +88,9 @@ static bool IsOpWhitelisted(Operation* op) {
|
|||
TypeID::get<TF::AddNOp>(),
|
||||
TypeID::get<TF::AddV2Op>(),
|
||||
TypeID::get<TF::AngleOp>(),
|
||||
TypeID::get<TF::AdjustContrastv2Op>(),
|
||||
TypeID::get<TF::AdjustHueOp>(),
|
||||
TypeID::get<TF::AdjustSaturationOp>(),
|
||||
TypeID::get<TF::ApproximateEqualOp>(),
|
||||
TypeID::get<TF::ArgMaxOp>(),
|
||||
TypeID::get<TF::ArgMinOp>(),
|
||||
|
@ -127,6 +130,7 @@ static bool IsOpWhitelisted(Operation* op) {
|
|||
TypeID::get<TF::GatherNdOp>(),
|
||||
TypeID::get<TF::GreaterEqualOp>(),
|
||||
TypeID::get<TF::GreaterOp>(),
|
||||
TypeID::get<TF::HSVToRGBOp>(),
|
||||
TypeID::get<TF::IFFT2DOp>(),
|
||||
TypeID::get<TF::IFFT3DOp>(),
|
||||
TypeID::get<TF::IFFTOp>(),
|
||||
|
@ -157,10 +161,14 @@ static bool IsOpWhitelisted(Operation* op) {
|
|||
TypeID::get<TF::PowOp>(),
|
||||
TypeID::get<TF::RFFT2DOp>(),
|
||||
TypeID::get<TF::RFFT3DOp>(),
|
||||
TypeID::get<TF::RGBToHSVOp>(),
|
||||
TypeID::get<TF::RealDivOp>(),
|
||||
TypeID::get<TF::ReciprocalOp>(),
|
||||
TypeID::get<TF::ReciprocalGradOp>(),
|
||||
TypeID::get<TF::Relu6GradOp>(),
|
||||
TypeID::get<TF::ResizeBilinearOp>(),
|
||||
TypeID::get<TF::ResizeBilinearGradOp>(),
|
||||
TypeID::get<TF::ResizeNearestNeighborOp>(),
|
||||
TypeID::get<TF::ReverseSequenceOp>(),
|
||||
TypeID::get<TF::RightShiftOp>(),
|
||||
TypeID::get<TF::RintOp>(),
|
||||
|
@ -337,9 +345,9 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
|
|||
|
||||
// Only static shaped operands are supported in XLA builders for now.
|
||||
for (Type ty : op->getOperandTypes()) {
|
||||
auto ranked_ty = ty.cast<ShapedType>();
|
||||
if (!ranked_ty.hasStaticShape()) {
|
||||
op->emitRemark() << "lowering requires static shaped operands";
|
||||
auto ranked_ty = ty.dyn_cast<ShapedType>();
|
||||
if (!ranked_ty || !ranked_ty.hasStaticShape()) {
|
||||
op->emitRemark() << "lowering requires static shaped tensor operands";
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -177,7 +177,7 @@ struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
|
|||
target.addIllegalOp<ReduceOp>();
|
||||
auto func = getFunction();
|
||||
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
|
||||
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
|
||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -43,7 +43,7 @@ class TestLhloToLLVMPass
|
|||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target.addIllegalDialect<XlaLhloDialect>();
|
||||
|
||||
if (failed(applyFullConversion(m, target, patterns, &converter))) {
|
||||
if (failed(applyFullConversion(m, target, patterns))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -711,7 +711,7 @@ struct LhloLegalizeToParallelLoops
|
|||
target.addIllegalOp<xla_lhlo::ReduceOp, xla_lhlo::ReduceWindowOp,
|
||||
xla_lhlo::SelectAndScatterOp>();
|
||||
|
||||
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
|
||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -61,6 +61,7 @@ MAP_HLO_TO_LHLO(MulOp);
|
|||
MAP_HLO_TO_LHLO(NegOp);
|
||||
MAP_HLO_TO_LHLO(RealOp);
|
||||
MAP_HLO_TO_LHLO(ReduceOp);
|
||||
MAP_HLO_TO_LHLO(ReshapeOp);
|
||||
MAP_HLO_TO_LHLO(RemOp);
|
||||
MAP_HLO_TO_LHLO(RsqrtOp);
|
||||
MAP_HLO_TO_LHLO(SelectOp);
|
||||
|
|
|
@ -867,7 +867,7 @@ struct LhloLegalizeToLinalg
|
|||
|
||||
auto func = getFunction();
|
||||
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
|
||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
@ -882,7 +882,7 @@ struct HloLegalizeToLinalg
|
|||
|
||||
auto func = getFunction();
|
||||
xla_hlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
|
||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Test DeviceIndex selector.
|
||||
|
||||
// RUN: tf-opt --tfl-device-index-selector %s | FileCheck %s
|
||||
// RUN: tf-opt --tf-device-index-selector %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @select
|
||||
func @select(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<i32>, tensor<f32>) {
|
|
@ -770,6 +770,7 @@ tf_xla_py_test(
|
|||
size = "small",
|
||||
timeout = "long",
|
||||
srcs = ["image_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
shard_count = 10,
|
||||
tags = [
|
||||
|
|
|
@ -1564,16 +1564,12 @@ XlaOp XlaBuilder::CustomCall(
|
|||
const Shape& shape, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
if (absl::StartsWith(call_target_name, "$")) {
|
||||
return InvalidArgument(
|
||||
"Invalid custom_call_target \"%s\": Call targets that start with '$' "
|
||||
"are reserved for internal use.",
|
||||
call_target_name);
|
||||
}
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
instr.set_custom_call_target(call_target_name);
|
||||
instr.set_backend_config(opaque);
|
||||
if (operand_shapes_with_layout.has_value()) {
|
||||
if (!LayoutUtil::HasLayout(shape)) {
|
||||
return InvalidArgument(
|
||||
|
@ -1586,7 +1582,6 @@ XlaOp XlaBuilder::CustomCall(
|
|||
"with constrained layout; given %d shapes, expected %d",
|
||||
operand_shapes_with_layout->size(), operands.size());
|
||||
}
|
||||
instr.set_constrain_layout(true);
|
||||
int64 operand_num = 0;
|
||||
for (const Shape& operand_shape : *operand_shapes_with_layout) {
|
||||
if (!LayoutUtil::HasLayout(operand_shape)) {
|
||||
|
@ -1595,14 +1590,31 @@ XlaOp XlaBuilder::CustomCall(
|
|||
"constrained layout.",
|
||||
operand_num);
|
||||
}
|
||||
*instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
|
||||
++operand_num;
|
||||
}
|
||||
}
|
||||
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
|
||||
return CustomCallInternal(call_target_name, operands, shape, opaque,
|
||||
operand_shapes_with_layout);
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::CustomCallInternal(
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const Shape& shape, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
instr.set_custom_call_target(call_target_name);
|
||||
instr.set_backend_config(opaque);
|
||||
if (operand_shapes_with_layout.has_value()) {
|
||||
instr.set_constrain_layout(true);
|
||||
for (const Shape& operand_shape : *operand_shapes_with_layout) {
|
||||
*instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
|
||||
}
|
||||
}
|
||||
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::CustomCall(
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const XlaComputation& computation, const Shape& shape, const string& opaque,
|
||||
|
|
|
@ -527,6 +527,14 @@ class XlaBuilder {
|
|||
const Shape& shape_with_layout, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
|
||||
|
||||
// Internal version of CustomCall without computation that doesn't do op
|
||||
// specific error handling and expects arguments to be legal. CustomCall
|
||||
// method above calls this method after error handling.
|
||||
virtual StatusOr<XlaOp> CustomCallInternal(
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const Shape& shape_with_layout, const string& opaque,
|
||||
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
|
||||
|
||||
XlaOp CustomCall(
|
||||
const string& call_target_name, absl::Span<const XlaOp> operands,
|
||||
const XlaComputation& computation, const Shape& shape_with_layout,
|
||||
|
|
|
@ -141,7 +141,9 @@ cc_library(
|
|||
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
|
||||
"//tensorflow/core:allocator",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/profiler/lib:connected_traceme",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/core/profiler/lib:traceme_encode",
|
||||
"//tensorflow/stream_executor:event",
|
||||
"//tensorflow/stream_executor:stream",
|
||||
"//tensorflow/stream_executor/host:host_platform_id",
|
||||
|
|
|
@ -127,11 +127,15 @@ std::unique_ptr<se::Stream> LocalDeviceState::BorrowStreamFromPool() {
|
|||
} else {
|
||||
std::unique_ptr<se::Stream> stream = std::move(usage_stream_pool_.top());
|
||||
usage_stream_pool_.pop();
|
||||
stream->RefreshStatus().IgnoreError(); // Can return error::Unimplemented
|
||||
QCHECK(stream->ok());
|
||||
return stream;
|
||||
}
|
||||
}
|
||||
|
||||
void LocalDeviceState::ReturnStreamToPool(std::unique_ptr<se::Stream> stream) {
|
||||
stream->RefreshStatus().IgnoreError(); // Can return error::Unimplemented
|
||||
QCHECK(stream->ok());
|
||||
absl::MutexLock lock(&mu_);
|
||||
usage_stream_pool_.push(std::move(stream));
|
||||
}
|
||||
|
|
|
@ -98,7 +98,9 @@ limitations under the License.
|
|||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/lib/connected_traceme.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme_encode.h"
|
||||
#include "tensorflow/stream_executor/device_memory.h"
|
||||
#include "tensorflow/stream_executor/device_memory_allocator.h"
|
||||
#include "tensorflow/stream_executor/event.h"
|
||||
|
@ -749,16 +751,22 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
|
|||
// memory that has already been allocated, and a possible Event
|
||||
// allocation.
|
||||
|
||||
se::Stream* h2d_stream = local_device->host_to_device_stream();
|
||||
ShapedBuffer buffer = device_buffer->AsShapedBuffer(
|
||||
compact_shape, on_device_shape, client->client()->platform());
|
||||
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
|
||||
local_device->host_to_device_stream(), literal, buffer));
|
||||
h2d_stream, literal, buffer));
|
||||
|
||||
std::shared_ptr<BufferSequencingEvent> event =
|
||||
device_buffer->definition_events()[0];
|
||||
TF_CHECK_OK(AddDestinationBufferSynchronization(
|
||||
local_device, std::move(device_buffer), event,
|
||||
local_device->host_to_device_stream()));
|
||||
local_device, std::move(device_buffer), event, h2d_stream));
|
||||
|
||||
// This can sometimes catch the case where the literal memory has been
|
||||
// freed before the H2D transfer was issued.
|
||||
h2d_stream->RefreshStatus()
|
||||
.IgnoreError(); // Can return error::Unimplemented
|
||||
QCHECK(h2d_stream->ok());
|
||||
};
|
||||
client->h2d_transfer_pool()->Schedule(transfer_h2d);
|
||||
return py_buffer;
|
||||
|
@ -1069,13 +1077,17 @@ Status PjRtBuffer::CopyToHostAsync() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral() {
|
||||
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
|
||||
const bool discard_cached_copy) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral");
|
||||
TF_RETURN_IF_ERROR(CopyToHostAsync());
|
||||
std::shared_ptr<HostValue> host_value;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
host_value = host_value_;
|
||||
if (discard_cached_copy) {
|
||||
host_value_ = nullptr;
|
||||
}
|
||||
}
|
||||
if (host_value == nullptr) {
|
||||
return InvalidArgument("ToLiteral called on invalid buffer");
|
||||
|
@ -1429,10 +1441,9 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
|
|||
int executable_idx, const RunId& run_id, const ExecuteOptions& options,
|
||||
Device* device, std::vector<PjRtBuffer::ScopedHold>* device_buffers) const {
|
||||
int device_ordinal = device->local_device_state()->device_ordinal();
|
||||
tensorflow::profiler::TraceMe traceme([&] {
|
||||
return absl::StrCat("LocalExecutable::Execute#run_id=", run_id.ToInt(),
|
||||
"#");
|
||||
});
|
||||
tensorflow::profiler::TraceMeConsumer activity(
|
||||
"LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt,
|
||||
run_id.ToInt());
|
||||
VLOG(3) << "Replica " << replica << ", partition " << partition
|
||||
<< " mapped to device ordinal for execution: " << device_ordinal;
|
||||
|
||||
|
@ -1721,10 +1732,9 @@ PjRtExecutable::ExecuteOnLocalDevices(
|
|||
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
|
||||
const ExecuteOptions& options) const {
|
||||
RunId run_id;
|
||||
tensorflow::profiler::TraceMe traceme([&] {
|
||||
return absl::StrCat(
|
||||
"LocalExecutable::ExecuteOnLocalDevices#run_id=", run_id.ToInt(), "#");
|
||||
});
|
||||
tensorflow::profiler::TraceMeProducer activity(
|
||||
"LocalExecutable::ExecuteOnLocalDevices",
|
||||
tensorflow::profiler::ContextType::kPjRt, run_id.ToInt());
|
||||
|
||||
const int num_local_devices = local_devices_.size();
|
||||
|
||||
|
|
|
@ -478,8 +478,12 @@ class PjRtBuffer {
|
|||
|
||||
// Returns the buffer's value as an XLA Literal. If the value has previously
|
||||
// been prefetched to the host, then returns the prefetched version, otherwise
|
||||
// copies the buffer to the host. Blocks until the value is ready.
|
||||
StatusOr<std::shared_ptr<Literal>> ToLiteral();
|
||||
// copies the buffer to the host. Blocks until the value is ready. If
|
||||
// `discard_cached_copy` is true, then the buffer will no longer keep hold of
// a cached copy of the literal (i.e., the reference to the host value will
// be removed).
|
||||
StatusOr<std::shared_ptr<Literal>> ToLiteral(
|
||||
bool discard_cached_copy = false);
|
||||
|
||||
// Initiates a copy of the buffer to the host. Does not block waiting for
|
||||
// the transfer to complete. The value can be retrieved by a later call to
|
||||
|
|
|
@ -106,7 +106,6 @@ class BranchVisitor {
|
|||
boundaries_.emplace_back(operand, i, inst);
|
||||
continue;
|
||||
}
|
||||
|
||||
worklist_.push_back(operand);
|
||||
visited_.insert(operand);
|
||||
}
|
||||
|
@ -197,6 +196,7 @@ bool WorthHoisting(HloInstruction* instruction) {
|
|||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kTuple:
|
||||
case HloOpcode::kSqrt:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
return true;
|
||||
default:
|
||||
|
@ -206,10 +206,11 @@ bool WorthHoisting(HloInstruction* instruction) {
|
|||
|
||||
// Compare if the instructions to be visited at each branches are identical.
|
||||
bool InstructionWithinBranchIdentical(
|
||||
const std::vector<HloInstruction*>& instructions, bool is_layout_senstive) {
|
||||
const std::vector<HloInstruction*>& instructions,
|
||||
bool is_layout_sensitive) {
|
||||
// Identical includes the shape of each operands are equal.
|
||||
auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) {
|
||||
bool eq_operands = is_layout_senstive
|
||||
bool eq_operands = is_layout_sensitive
|
||||
? ShapeUtil::Equal(a->shape(), b->shape())
|
||||
: ShapeUtil::Compatible(a->shape(), b->shape());
|
||||
return eq_operands;
|
||||
|
@ -233,7 +234,7 @@ bool InstructionWithinBranchIdentical(
|
|||
auto old_channel_id = instruction->channel_id();
|
||||
instruction->set_channel_id(instructions[0]->channel_id());
|
||||
bool eq_instructions = instructions[0]->Identical(
|
||||
*instruction, eq_operand, eq_computations, is_layout_senstive);
|
||||
*instruction, eq_operand, eq_computations, is_layout_sensitive);
|
||||
instruction->set_channel_id(old_channel_id);
|
||||
return eq_instructions;
|
||||
});
|
||||
|
@ -243,7 +244,7 @@ bool InstructionWithinBranchIdentical(
|
|||
[&](HloInstruction* instruction) {
|
||||
return instructions[0]->Identical(
|
||||
*instruction, eq_operand, eq_computations,
|
||||
is_layout_senstive);
|
||||
is_layout_sensitive);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -354,12 +355,228 @@ Status RemoveInstructionFromComputation(
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Identify converts to be hoisted/rematerialized out of the branch
|
||||
// computations.
|
||||
absl::flat_hash_set<int64> FindSpecialConverts(HloInstruction* old_root,
|
||||
int branch_count,
|
||||
HloInstruction* conditional,
|
||||
bool is_layout_sensitive) {
|
||||
absl::flat_hash_set<int64> kspecial_convert;
|
||||
for (int64 operand_num = 0; operand_num < old_root->operand_count();
|
||||
++operand_num) {
|
||||
if (old_root->operand(operand_num)->opcode() != HloOpcode::kConvert) {
|
||||
continue;
|
||||
}
|
||||
bool replica = true;
|
||||
HloInstruction* kspecial_convert_candidate =
|
||||
old_root->mutable_operand(operand_num);
|
||||
// Check whether an identical candidate appears in other branches
|
||||
for (int others = 1; others < branch_count; ++others) {
|
||||
HloInstruction* others_root =
|
||||
conditional->branch_computation(others)->root_instruction();
|
||||
bool eq_shape =
|
||||
is_layout_sensitive
|
||||
? ShapeUtil::Equal(others_root->operand(operand_num)->shape(),
|
||||
kspecial_convert_candidate->shape())
|
||||
: ShapeUtil::Compatible(
|
||||
others_root->operand(operand_num)->shape(),
|
||||
kspecial_convert_candidate->shape());
|
||||
if ((others_root->operand(operand_num)->opcode() ==
|
||||
HloOpcode::kConvert) &&
|
||||
eq_shape) {
|
||||
// Nothing to be done.
|
||||
} else {
|
||||
replica = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (replica) {
|
||||
kspecial_convert.insert(operand_num);
|
||||
}
|
||||
}
|
||||
return kspecial_convert;
|
||||
}
|
||||
|
||||
// Restructuring the conditional instruction as follows:
|
||||
// i.e., %result = conditional() becomes
|
||||
// x = conditional()
|
||||
// y.{0..n} = gte(x, {0..n})
|
||||
// z = tuple(y.0, y.1, ...y.n)
|
||||
// Doing so ensures that we can accommodate the possible shape-change of the
|
||||
// conditional when the instructions are hoisted.
|
||||
Status RestructureConditionalInstruction(HloComputation* computation,
|
||||
HloInstruction* conditional) {
|
||||
HloInstruction* old_root = computation->root_instruction();
|
||||
std::vector<HloInstruction*> new_operands;
|
||||
int cur_index = 0;
|
||||
for (; cur_index < ShapeUtil::TupleElementCount(conditional->shape());
|
||||
++cur_index) {
|
||||
new_operands.push_back(
|
||||
computation->AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
ShapeUtil::GetTupleElementShape(conditional->shape(), cur_index),
|
||||
conditional, cur_index)));
|
||||
}
|
||||
HloInstruction* new_tuple =
|
||||
computation->AddInstruction(HloInstruction::CreateTuple(new_operands));
|
||||
if (old_root == conditional) {
|
||||
computation->set_root_instruction(new_tuple);
|
||||
} else {
|
||||
std::vector<HloInstruction*> new_tuple_users;
|
||||
for (auto conditional_user : conditional->users()) {
|
||||
auto is_new_gte = absl::c_find_if(
|
||||
new_operands,
|
||||
[&](HloInstruction* instr) { return instr == conditional_user; });
|
||||
if (is_new_gte == new_operands.end()) {
|
||||
new_tuple_users.push_back(conditional_user);
|
||||
}
|
||||
}
|
||||
for (auto new_tuple_user : new_tuple_users) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
conditional->ReplaceUseWith(new_tuple_user, new_tuple));
|
||||
}
|
||||
}
|
||||
VLOG(2) << "computation after root restructure:\n" << computation->ToString();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<bool> ConvertSpecialMove(HloInstruction* conditional,
|
||||
bool is_layout_sensitive) {
|
||||
int branch_count = conditional->branch_count();
|
||||
if (branch_count <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
HloInstruction* old_root =
|
||||
conditional->branch_computation(0)->root_instruction();
|
||||
if (old_root->opcode() != HloOpcode::kTuple) {
|
||||
return false;
|
||||
} else {
|
||||
VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString();
|
||||
// Identify the gte using `index'.
|
||||
auto find_gte = [](const HloInstruction* conditional_result,
|
||||
int64 index) -> HloInstruction* {
|
||||
for (HloInstruction* instr : conditional_result->users()) {
|
||||
if (instr->opcode() != HloOpcode::kGetTupleElement) {
|
||||
return nullptr;
|
||||
}
|
||||
if (instr->tuple_index() == index) {
|
||||
return instr;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
// Captures tuple indices referring to converts to be rematerialized/hoisted.
|
||||
    absl::flat_hash_set<int64> kspecial_convert = FindSpecialConverts(
        old_root, branch_count, conditional, is_layout_sensitive);

    // Exit if we cannot find any converts to be hoisted.
    if (kspecial_convert.empty()) {
      return false;
    }

    TF_RETURN_IF_ERROR(
        RestructureConditionalInstruction(conditional->parent(), conditional));

    for (int branch = 0; branch < branch_count; branch++) {
      old_root = conditional->branch_computation(branch)->root_instruction();
      absl::flat_hash_map<HloInstruction*, int64> map_inst_to_tuple_index;
      std::vector<HloInstruction*> new_operands(old_root->operand_count());
      std::unordered_set<HloInstruction*> to_hoist_set;

      for (int64 operand_num = 0; operand_num < old_root->operand_count();
           ++operand_num) {
        map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] =
            operand_num;
      }
      for (int64 operand_num = 0; operand_num < old_root->operand_count();
           ++operand_num) {
        HloInstruction* hoist = old_root->mutable_operand(operand_num);
        if (!kspecial_convert.contains(operand_num)) {
          new_operands[operand_num] = old_root->mutable_operand(operand_num);
          continue;
        }

        to_hoist_set.insert(hoist);
        int64 new_tuple_count = old_root->operand_count();

        // Replace the hoisted instr in the tuple with the operand/operands.
        // We will replace at least one of the operands of the hoist at the
        // tuple place; the rest will be added at the end.
        bool inplace = true;
        CHECK(!hoist->operands().empty());
        for (HloInstruction* prod : hoist->operands()) {
          if (inplace) {
            map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist];
            new_operands[map_inst_to_tuple_index[hoist]] = prod;
            inplace = false;
          } else {
            map_inst_to_tuple_index[prod] = new_tuple_count++;
            new_operands.push_back(prod);
          }
        }
      }

      // Create the new root instruction.
      HloComputation* cur_branch = conditional->branch_computation(branch);
      HloInstruction* new_branch_root =
          cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands));
      // The shape can vary since the operands to convert are now
      // being returned through the branches' root.
      cur_branch->set_root_instruction(new_branch_root, true /*new shape*/);
      TF_CHECK_OK(cur_branch->RemoveInstruction(old_root));

      // Only one of the branches needs to change the conditional->parent().
      if (branch != 0) {
        continue;
      }
      HloComputation* conditional_parent = conditional->parent();
      HloInstruction* newconditional =
          conditional_parent->AddInstruction(HloInstruction::CreateConditional(
              cur_branch->root_instruction()->shape(),
              conditional->mutable_operand(0),
              absl::MakeSpan(conditional->branch_computations()),
              absl::MakeSpan(conditional->operands()).subspan(1)));
      // Ensure that all the users of conditional refer to the new one.
      TF_RETURN_IF_ERROR(
          conditional->ReplaceAllUsesWithDifferentShape(newconditional));
      TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional));
      conditional = newconditional;
      // Add the hoisted instructions in the parent.
      for (HloInstruction* hoist : to_hoist_set) {
        VLOG(2) << "Hoisting instruction:" << hoist->ToString();
        int64 hoist_index = map_inst_to_tuple_index[hoist];
        // Find out the gte that captured the hoisted instr result.
        HloInstruction* gte_hoist = find_gte(conditional, hoist_index);
        CHECK(gte_hoist != nullptr);
        std::vector<HloInstruction*> new_operands;
        for (HloInstruction* op : hoist->operands()) {
          HloInstruction* gte = conditional_parent->AddInstruction(
              HloInstruction::CreateGetTupleElement(
                  op->shape(), conditional, map_inst_to_tuple_index[op]));
          new_operands.push_back(gte);
        }
        HloInstruction* hoisted = conditional_parent->AddInstruction(
            hoist->CloneWithNewOperands(hoist->shape(), new_operands));
        VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString();
        TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted));
        TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist));
      }
      // No need to explicitly delete a hoisted instruction since, if it is
      // dead, the subsequent DCE will remove it.
    }
  }
  VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString();
  return true;
}

// Hoist identical ops out of the conditional. Ops are considered identical
// when the shapes of their operands and their properties are identical.
// Starts from the root instruction of each branch and collects the identical
// ops to hoist.
StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
                                      bool is_layout_sensitive) {
  VLOG(1) << " visiting conditional:" << conditional->ToString();
  int branch_count = conditional->branch_count();
  if (branch_count <= 0) {
    return false;
@ -399,7 +616,7 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
    }
  }

  if (visitors[0].HoistInstructionSize() <= 1) {
  if (visitors[0].HoistInstructionSize() < 1) {
    return false;
  }

@ -442,7 +659,6 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
        RemoveInstructionFromComputation(visitors[i].instructions_to_hoist(),
                                         conditional->branch_computation(i)));
  }

  return true;
}

@ -451,26 +667,55 @@ StatusOr<bool> MergeIdenticalElements(HloInstruction* conditional,
StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
  bool changed = false;

  // Gather all the conditional ops in our module. We do this ahead of time so
  // we don't have to worry about mutating the lists of computations or
  // instructions as we iterate.
  std::vector<HloInstruction*> conditional_ops;
  for (auto* comp : module->MakeComputationPostOrder()) {
    for (auto* instr : comp->MakeInstructionPostOrder()) {
      if (instr->opcode() == HloOpcode::kConditional) {
        conditional_ops.push_back(instr);
  if (pursue_full_conditional_code_motion_) {
    std::vector<HloInstruction*> conditional_ops;
    for (auto* comp : module->MakeComputationPostOrder()) {
      for (auto* instr : comp->MakeInstructionPostOrder()) {
        if (instr->opcode() == HloOpcode::kConditional) {
          conditional_ops.push_back(instr);
        }
      }
    }

    for (HloInstruction* conditional_op : conditional_ops) {
      TF_ASSIGN_OR_RETURN(
          bool result,
          MergeIdenticalElements(conditional_op, is_layout_sensitive_));
      changed |= result;
    }

    if (changed) {
      HloPassPipeline subpipeline("after_conditional_code_motion");
      subpipeline.AddPass<HloDCE>();
      subpipeline.AddPass<TupleSimplifier>();
      subpipeline.AddPass<HloDCE>();
      TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));
      changed |= cleanup_changed;
    }
  }

  for (HloInstruction* conditional_op : conditional_ops) {
    TF_ASSIGN_OR_RETURN(bool result, MergeIdenticalElements(
                                         conditional_op, is_layout_sensitive_));
    changed |= result;
  // handling convert rematerialization/hoisting
  {
    std::vector<HloInstruction*> conditional_ops;
    for (auto* comp : module->MakeComputationPostOrder()) {
      for (auto* instr : comp->MakeInstructionPostOrder()) {
        if (instr->opcode() == HloOpcode::kConditional) {
          conditional_ops.push_back(instr);
        }
      }
    }
    for (HloInstruction* conditional_op : conditional_ops) {
      TF_ASSIGN_OR_RETURN(
          bool convert_result,
          ConvertSpecialMove(conditional_op, is_layout_sensitive_));
      changed |= convert_result;
    }
  }

  if (changed) {
    HloPassPipeline subpipeline("after_conditional_code_motion");
    HloPassPipeline subpipeline(
        "after_conditional_code_motion_after_convert_hoisting");
    subpipeline.AddPass<HloDCE>();
    subpipeline.AddPass<TupleSimplifier>();
    subpipeline.AddPass<HloDCE>();
    TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));

@ -23,7 +23,11 @@ limitations under the License.

namespace xla {

// HLO pass that moves identical ops out of conditional.
// ConditionalCodeMotion specializes in hoisting/rematerializing
// unconditional converts in the default mode.
// When pursue_full_conditional_code_motion_ is set to true, the
// full HLO pass moves identical ops out of a conditional in addition to moving
// converts.
// - Two ops are considered identical when the shapes of their operands and
//   their properties are identical.
// - Currently, only some types of instructions are supported.
@ -35,13 +39,18 @@ class ConditionalCodeMotion : public HloModulePass {
 public:
  // If is_layout_sensitive is true, then the hoist process preserves layout
  // during identical comparison. Otherwise, layout is ignored.
  explicit ConditionalCodeMotion(bool is_layout_sensitive = true)
      : is_layout_sensitive_(is_layout_sensitive) {}
  explicit ConditionalCodeMotion(
      bool is_layout_sensitive = true,
      bool pursue_full_conditional_code_motion = false)
      : is_layout_sensitive_(is_layout_sensitive),
        pursue_full_conditional_code_motion_(
            pursue_full_conditional_code_motion) {}
  absl::string_view name() const override { return "conditional-code-motion"; }
  StatusOr<bool> Run(HloModule* module) override;

 private:
  const bool is_layout_sensitive_;
  const bool pursue_full_conditional_code_motion_;
};

} // namespace xla

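For illustration only (not part of this change): a minimal sketch of how the new constructor could be wired into a pass pipeline; the pipeline name here is invented.

HloPassPipeline pipeline("example-pipeline");
// Default construction only hoists/rematerializes unconditional converts;
// setting the second flag also enables hoisting of identical ops.
pipeline.AddPass<ConditionalCodeMotion>(
    /*is_layout_sensitive=*/true,
    /*pursue_full_conditional_code_motion=*/true);
TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));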
@ -38,7 +38,86 @@ namespace {
using ConditionalCodeMotionTest = HloTestBase;
namespace op = xla::testing::opcode_matchers;

TEST_F(ConditionalCodeMotionTest, DoNotMoveConvertOut) {
TEST_F(ConditionalCodeMotionTest, MoveSubsetTupleOut) {
  absl::string_view hlo_string =
      R"(
HloModule RemoveDotOpOut

on_true {
  %arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
  %get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
  %reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
  %convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.8493)
  ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(%convert.2894, %reshape.8493)
}

on_false {
  %arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
  %get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
  %reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
  %add = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717)
  %convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
  ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(%convert.3604, %add)
}

ENTRY main {
  pred.1 = pred[] parameter(0)
  arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
  arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
  conditional = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
  get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
  get-first-index.2 = f32[2,512,364]{2,1,0} get-tuple-element(conditional), index=1
  ROOT result = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(get-first-index, get-first-index.2)
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
  ConditionalCodeMotion pass(true, true);
  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(), op::GetTupleElement())));
}

TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditionalRoot) {
  absl::string_view hlo_string =
      R"(
HloModule RemoveDotOpOut

on_true {
  %arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0)
  %get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0
  %reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
  %add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.8493, f32[2,512,364]{2,1,0} %reshape.8493)
  %convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %add.8493)
  ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894)
}

on_false {
  %arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0)
  %get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0
  %reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
  %add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717)
  %sub.8493 = f32[2,512,364]{2,1,0} subtract(f32[2,512,364]{2,1,0} %add.8493, f32[2,512,364]{2,1,0} %reshape.9717)
  %convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"}
  ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604)
}

ENTRY main {
  pred.1 = pred[] parameter(0)
  arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1)
  arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
  ROOT conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
  ConditionalCodeMotion pass(true, true);
  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Tuple(op::Convert())));
}

TEST_F(ConditionalCodeMotionTest, MoveConvertOut) {
  absl::string_view hlo_string =
      R"(
HloModule RemoveDotOpOut

@ -65,12 +144,16 @@ ENTRY main {
  arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2)
  conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false
  get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0
  ROOT result = (bf16[2,512,364]{2,1,0}) tuple(get-first-index)
  add.1 = bf16[2,512,364]{2,1,0} add(bf16[2,512,364]{2,1,0} get-first-index, bf16[2,512,364]{2,1,0} get-first-index)
  ROOT result = (bf16[2,512,364]{2,1,0}) tuple(add.1)
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
  ConditionalCodeMotion pass;
  ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
  ConditionalCodeMotion pass(true, true);
  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());

  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Tuple(op::Add(op::Convert(), op::Convert()))));
}

TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) {
@ -123,7 +206,7 @@ ENTRY main {
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
  ConditionalCodeMotion pass;
  ConditionalCodeMotion pass(true, true);
  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());

  const HloInstruction* conditional =
@ -181,7 +264,7 @@ ENTRY main {
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
  ConditionalCodeMotion pass;
  ConditionalCodeMotion pass(true, true);
  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
  const HloInstruction* conditional =
      FindInstruction(module.get(), "conditional");
@ -245,7 +328,7 @@ ENTRY main {
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
  ConditionalCodeMotion pass;
  ConditionalCodeMotion pass(true, true);
  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());

  const HloInstruction* conditional =
@ -317,7 +400,7 @@ ENTRY main {
)";

  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
  ConditionalCodeMotion pass;
  ConditionalCodeMotion pass(true, true);
  ASSERT_FALSE(pass.Run(&*module).ValueOrDie());
}

@ -390,7 +473,7 @@ ENTRY main {
}
)";
  auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
  ConditionalCodeMotion pass;
  ConditionalCodeMotion pass(true, true);
  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
  const HloInstruction* conditional =
      FindInstruction(module.get(), "conditional");

@ -226,6 +226,11 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) {
      dims_to_keep.push_back(dim);
    }
  }

  // We support fast codegen for three cases:
  // 1) Row reduction: (K, R)
  // 2) Column reduction: (K, R, K)
  // 3) "Batched" row reduction: (R, K, R)
  if (!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                            dims_to_keep) &&
      !LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),

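For illustration only (not taken from this diff): one reading of the K/R notation above, where K marks runs of kept dimensions and R runs of reduced dimensions; the shapes are invented.

// f32[1024,128]   reduced over {1}   -> (K, R):    row reduction.
// f32[64,1024,32] reduced over {1}   -> (K, R, K):  column reduction.
// f32[8,1024,16]  reduced over {0,2} -> (R, K, R):  "batched" row reduction.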
@ -77,8 +77,6 @@ class KernelThunk : public Thunk {
  // Will be set by IrEmitterUnnested.
  LaunchDimensions launch_dimensions_;

  // Describes how to load this kernel. ExecuteOnStream reuses this loader
  // specification for all executions.
  mutable tensorflow::mutex mutex_;

  // Loaded kernels for each `StreamExecutor`. Requires pointer stability of

@ -3908,6 +3908,10 @@ const string& HloInstruction::outfeed_config() const {
  return Cast<HloOutfeedInstruction>(this)->outfeed_config();
}

void HloInstruction::set_outfeed_config(const string& config) {
  return Cast<HloOutfeedInstruction>(this)->set_outfeed_config(config);
}

const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
  return Cast<HloCollectiveInstruction>(this)->replica_groups();
}

@ -1755,6 +1755,9 @@ class HloInstruction {
  // Returns the config for the Outfeed instruction.
  const string& outfeed_config() const;

  // Delegates to HloOutfeedInstruction::set_outfeed_config.
  void set_outfeed_config(const string& config);

  // Returns the shape for the Outfeed instruction.
  const Shape& outfeed_shape() const;

@ -1141,6 +1141,7 @@ class HloOutfeedInstruction : public HloInstruction {
  const Shape& outfeed_shape() const { return outfeed_shape_; }
  // Returns the config for the Outfeed instruction.
  const string& outfeed_config() const { return outfeed_config_; }
  void set_outfeed_config(const string& config) { outfeed_config_ = config; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

@ -29,36 +29,78 @@ StatusOr<bool> MemorySpacePropagation::Run(HloModule* module) {
      // Propagate the operand subshapes.
      for (int operand_idx = 0; operand_idx < instruction->operand_count();
           ++operand_idx) {
        modified |=
            PropagateSubshapes(instruction->operand(operand_idx)->shape(),
                               instruction->fused_parameter(operand_idx));
        for (const ShapeUtil::IndexedShape& indexed_shape :
             ShapeUtil::GetLeafShapes(
                 instruction->operand(operand_idx)->shape())) {
          int64 memory_space = indexed_shape.shape.layout().memory_space();
          modified |= Propagate(indexed_shape.index,
                                instruction->fused_parameter(operand_idx),
                                memory_space);
        }
      }

      // Propagate output subshapes.
      modified |= PropagateSubshapes(instruction->shape(),
                                     instruction->fused_expression_root());
      for (const ShapeUtil::IndexedShape& indexed_shape :
           ShapeUtil::GetLeafShapes(instruction->shape())) {
        int64 memory_space = indexed_shape.shape.layout().memory_space();
        modified |=
            Propagate(indexed_shape.index,
                      instruction->fused_expression_root(), memory_space);
      }
    }
  }
}
return modified;
}

bool MemorySpacePropagation::PropagateSubshapes(
    const Shape& caller_shape, const HloInstruction* callee_instruction) const {
bool MemorySpacePropagation::Propagate(ShapeIndexView index,
                                       const HloInstruction* callee_instruction,
                                       int64 memory_space) const {
  bool modified = false;
  for (const ShapeUtil::IndexedShape& indexed_shape :
       ShapeUtil::GetLeafShapes(caller_shape)) {
    int64 memory_space = indexed_shape.shape.layout().memory_space();
    const HloValue& value = dataflow_analysis_->GetUniqueValueAt(
        callee_instruction, indexed_shape.index);
  const HloValue& value = dataflow_analysis_->GetUniqueValueAt(
      callee_instruction, index.ToShapeIndex());

    for (const HloPosition& position : value.positions()) {
      Shape* shape = ShapeUtil::GetMutableSubshape(
          position.instruction->mutable_shape(), position.index);
      if (shape->layout().memory_space() != memory_space) {
        shape->mutable_layout()->set_memory_space(memory_space);
        modified = true;
      }
  for (const HloPosition& position : value.positions()) {
    HloInstruction* instruction = position.instruction;
    Shape* shape = ShapeUtil::GetMutableSubshape(instruction->mutable_shape(),
                                                 position.index);
    if (shape->layout().memory_space() == memory_space) {
      continue;
    }
    shape->mutable_layout()->set_memory_space(memory_space);
    modified = true;

    // For fusion outputs, propagate the memory space to the fusion root.
    if (instruction->opcode() == HloOpcode::kFusion) {
      Propagate(position.index, instruction->fused_expression_root(),
                memory_space);
    }

    const HloInstruction* parent_fusion =
        instruction->parent()->FusionInstruction();
    // For nested fusion roots, pop one level up and propagate the memory space
    // to the output of the calling fusion instruction.
    if (instruction == instruction->parent()->root_instruction() &&
        parent_fusion->parent()->IsFusionComputation()) {
      Propagate(position.index, parent_fusion, memory_space);
    }

    // For nested fusion parameters, pop one level up and propagate the memory
    // space to the operand of the calling fusion instruction.
    if (instruction->opcode() == HloOpcode::kParameter &&
        parent_fusion->parent()->IsFusionComputation()) {
      const HloInstruction* fusion_operand =
          parent_fusion->operand(instruction->parameter_number());
      Propagate(position.index, fusion_operand, memory_space);
    }
  }

  for (const HloUse& use : value.uses()) {
    // For fusion uses, propagate the memory space to the fusion parameter.
    if (use.instruction->opcode() == HloOpcode::kFusion) {
      modified |= Propagate(
          use.operand_index,
          use.instruction->fused_parameter(use.operand_number), memory_space);
    }
  }
  return modified;

@ -31,12 +31,11 @@ class MemorySpacePropagation : public HloModulePass {
  StatusOr<bool> Run(HloModule* module) override;

 private:
  // Given the caller shape (operand or output) and its corresponding
  // instruction in the fused computation (parameter or root), propagates the
  // memory space to all the subshapes in the callee side. Returns true if the
  // module is modified.
  bool PropagateSubshapes(const Shape& caller_shape,
                          const HloInstruction* callee_instruction) const;
  // Given the shape index (operand or output) and its corresponding instruction
  // in the fused computation (parameter or root), propagates the memory space
  // in the callee side. Returns true if the module is modified.
  bool Propagate(ShapeIndexView index, const HloInstruction* callee_instruction,
                 int64 memory_space) const;

  std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
};

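For illustration only (mirroring the tests below rather than prescribing an API): a minimal sketch of driving the pass; hlo_string is assumed to hold a module whose fusion operands and outputs carry memory-space annotations such as S(1) in their layouts.

TF_ASSERT_OK_AND_ASSIGN(auto module,
                        ParseAndReturnUnverifiedModule(hlo_string));
MemorySpacePropagation memory_space_propagation;
// Propagates memory spaces such as S(1) from fusion operands and outputs
// into the fused computations, including nested fusions.
EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());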
@ -199,5 +199,153 @@ TEST_F(MemorySpacePropagationTest, TupleOutput) {
  EXPECT_EQ(module->Hash(), ref->Hash());
}

TEST_F(MemorySpacePropagationTest, NestedInputFusion) {
  // Tests propagating the memory space to nested fusions on the input side.
  absl::string_view hlo_string = R"(
HloModule NestedFusion

%bitcast_fusion {
  %bf_param = s32[3,2]{0,1:T(128)} parameter(0)
  ROOT %bitcast = s32[6]{0:T(128)} bitcast(%bf_param)
}

%fused_computation {
  %param_1.3 = s32[1]{0:T(128)} parameter(1)
  %constant.2 = s32[]{:T(128)} constant(-2147483648)
  %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
  %param_2.3 = s32[5]{0:T(128)} parameter(2)
  %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
  %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
  %param_0.1 = s32[3,2]{0,1:T(128)} parameter(0)
  %fusion.1 = s32[6]{0:T(128)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion
  ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %fusion.1)
}

ENTRY %entry {
  %param0 = s32[3,2]{0,1:T(128)} parameter(0)
  %param1 = s32[1]{0:T(128)} parameter(1)
  %param2 = s32[5]{0:T(128)} parameter(2)
  %arg0 = s32[3,2]{0,1:T(128)S(1)} copy(%param0)
  %arg1 = s32[1]{0:T(128)} copy(%param1)
  %arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
  %fusion = s32[6]{0:T(128)S(1)} fusion(s32[3,2]{0,1:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
  ROOT %root = s32[6]{0:T(128)} copy(%fusion)
}
)";
  absl::string_view expected_hlo_string = R"(
HloModule NestedFusion

%bitcast_fusion {
  %bf_param = s32[3,2]{0,1:T(128)S(1)} parameter(0)
  ROOT %bitcast = s32[6]{0:T(128)S(1)} bitcast(%bf_param)
}

%fused_computation {
  %param_1.3 = s32[1]{0:T(128)} parameter(1)
  %constant.2 = s32[]{:T(128)} constant(-2147483648)
  %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
  %param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
  %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
  %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
  %param_0.1 = s32[3,2]{0,1:T(128)S(1)} parameter(0)
  %fusion.1 = s32[6]{0:T(128)S(1)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion
  ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %fusion.1)
}

ENTRY %entry {
  %param0 = s32[3,2]{0,1:T(128)} parameter(0)
  %param1 = s32[1]{0:T(128)} parameter(1)
  %param2 = s32[5]{0:T(128)} parameter(2)
  %arg0 = s32[3,2]{0,1:T(128)S(1)} copy(%param0)
  %arg1 = s32[1]{0:T(128)} copy(%param1)
  %arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
  %fusion = s32[6]{0:T(128)S(1)} fusion(s32[3,2]{0,1:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
  ROOT %root = s32[6]{0:T(128)} copy(%fusion)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));
  MemorySpacePropagation memory_space_propagation;
  EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
  TF_EXPECT_OK(Verify(module.get()));
  TF_ASSERT_OK_AND_ASSIGN(auto ref,
                          ParseAndReturnVerifiedModule(expected_hlo_string));
  EXPECT_EQ(module->Hash(), ref->Hash());
}

TEST_F(MemorySpacePropagationTest, NestedOutputFusion) {
  // Tests propagating the memory space to nested fusions on the output side.
  absl::string_view hlo_string = R"(
HloModule NestedFusion

%bitcast_fusion {
  %bf_param = s32[6]{0:T(128)} parameter(0)
  ROOT %bitcast = s32[3,2]{0,1:T(128)} bitcast(%bf_param)
}

%fused_computation {
  %param_1.3 = s32[1]{0:T(128)} parameter(1)
  %constant.2 = s32[]{:T(128)} constant(-2147483648)
  %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
  %param_2.3 = s32[5]{0:T(128)} parameter(2)
  %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
  %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
  %param_0.1 = s32[6]{0:T(128)} parameter(0)
  %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
  ROOT %fusion.1 = s32[3,2]{0,1:T(128)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion
}

ENTRY %entry {
  %param0 = s32[6]{0:T(128)} parameter(0)
  %param1 = s32[1]{0:T(128)} parameter(1)
  %param2 = s32[5]{0:T(128)} parameter(2)
  %arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
  %arg1 = s32[1]{0:T(128)} copy(%param1)
  %arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
  %fusion = s32[3,2]{0,1:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
  ROOT %root = s32[3,2]{0,1:T(128)} copy(%fusion)
}
)";
  absl::string_view expected_hlo_string = R"(
HloModule NestedFusion

%bitcast_fusion {
  %bf_param = s32[6]{0:T(128)S(1)} parameter(0)
  ROOT %bitcast = s32[3,2]{0,1:T(128)S(1)} bitcast(%bf_param)
}

%fused_computation {
  %param_1.3 = s32[1]{0:T(128)} parameter(1)
  %constant.2 = s32[]{:T(128)} constant(-2147483648)
  %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
  %param_2.3 = s32[5]{0:T(128)S(1)} parameter(2)
  %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)S(1)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
  %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
  %param_0.1 = s32[6]{0:T(128)S(1)} parameter(0)
  %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1)
  ROOT %fusion.1 = s32[3,2]{0,1:T(128)S(1)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion
}

ENTRY %entry {
  %param0 = s32[6]{0:T(128)} parameter(0)
  %param1 = s32[1]{0:T(128)} parameter(1)
  %param2 = s32[5]{0:T(128)} parameter(2)
  %arg0 = s32[6]{0:T(128)S(1)} copy(%param0)
  %arg1 = s32[1]{0:T(128)} copy(%param1)
  %arg2 = s32[5]{0:T(128)S(1)} copy(%param2)
  %fusion = s32[3,2]{0,1:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation
  ROOT %root = s32[3,2]{0,1:T(128)} copy(%fusion)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));
  MemorySpacePropagation memory_space_propagation;
  EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie());
  TF_EXPECT_OK(Verify(module.get()));
  TF_ASSERT_OK_AND_ASSIGN(auto ref,
                          ParseAndReturnVerifiedModule(expected_hlo_string));
  EXPECT_EQ(module->Hash(), ref->Hash());
}

} // namespace
} // namespace xla

@ -552,7 +552,7 @@ class LowerToNVVMPass
    // TODO(csigg): Remove once we support replacing non-root ops.
    target.addLegalOp<::mlir::gpu::GPUModuleOp, ::mlir::gpu::ModuleEndOp,
                      ::mlir::gpu::YieldOp>();
    if (failed(mlir::applyFullConversion(m, target, patterns, &converter))) {
    if (failed(mlir::applyFullConversion(m, target, patterns))) {
      signalPassFailure();
    }
  }

@ -52,16 +52,26 @@ cc_library(
    name = "test_macros_header",
    testonly = True,
    hdrs = ["test_macros.h"],
    deps = [
        "//tensorflow/compiler/xla:types",
        "//tensorflow/core:test",
        "@com_google_absl//absl/strings",
    ],
)

# Generate a test_macros_${BACKEND} library per backend with the proper copts.
generate_backend_test_macros()

cc_library(
    name = "manifest_checking_test",
    testonly = True,
    srcs = ["manifest_checking_test.cc"],
    hdrs = ["manifest_checking_test.h"],
    deps = [
        ":test_macros_header",
        "//tensorflow/core:regexp_internal",
        "//tensorflow/core:test",
        "//tensorflow/core/platform:logging",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "test_utils",
    srcs = ["test_utils.cc"],
@ -136,6 +146,7 @@ cc_library(
    hdrs = ["hlo_test_base.h"],
    deps = [
        ":literal_test_util",
        ":manifest_checking_test",
        ":test_utils",
        ":verified_hlo_module",
        "//tensorflow/compiler/xla:debug_options_flags",
@ -193,6 +204,7 @@ cc_library(
    srcs = ["client_library_test_base.cc"],
    hdrs = ["client_library_test_base.h"],
    deps = [
        ":manifest_checking_test",
        "//tensorflow/compiler/xla:array2d",
        "//tensorflow/compiler/xla:array3d",
        "//tensorflow/compiler/xla:array4d",
@ -273,6 +285,7 @@ cc_library(
    hdrs = ["local_client_test_base.h"],
    deps = [
        ":client_library_test_base",
        ":manifest_checking_test",
        ":verified_hlo_module",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",

@ -266,11 +266,6 @@ def generate_backend_test_macros(backends = []):
            "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
        ],
        deps = [
            "@com_google_absl//absl/container:flat_hash_map",
            "@com_google_absl//absl/strings",
            "//tensorflow/compiler/xla:types",
            "//tensorflow/core:lib",
            "//tensorflow/core:regexp_internal",
            "//tensorflow/core:test",
            "//tensorflow/core/platform:logging",
        ],
    )

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
@ -62,7 +63,7 @@ std::vector<TestCase> ExpandUseBfloat16(
}

// A client library test establishes an in-process XLA client connection.
class ClientLibraryTestBase : public ::testing::Test {
class ClientLibraryTestBase : public ManifestCheckingTest {
 protected:
  explicit ClientLibraryTestBase(se::Platform* platform = nullptr);

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -67,7 +68,7 @@ namespace xla {
// )
//
// For a more detailed example, see "../tests/sample_text_test.cc".
class HloTestBase : public ::testing::Test {
class HloTestBase : public ManifestCheckingTest {
 public:
  // Creates a new HLO module for a test. The module created will have
  // TestName() for its name; it will also automatically populate its debug

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/mutex.h"
@ -75,7 +76,7 @@ class TestAllocator : public se::StreamExecutorMemoryAllocator {
};

// A base class for tests which exercise the LocalClient interface.
class LocalClientTestBase : public ::testing::Test {
class LocalClientTestBase : public ManifestCheckingTest {
 protected:
  struct EigenThreadPoolWrapper;
  explicit LocalClientTestBase(se::Platform* platform = nullptr);

@ -0,0 +1,129 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"

#include <fstream>
#include <iterator>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"

namespace xla {

namespace {

// Mapping from a test name (e.g. MyTest.MyTestCase) to the platforms on which
// it is disabled - a sequence of regexps.
using ManifestT = absl::flat_hash_map<std::string, std::vector<std::string>>;

ManifestT ReadManifest() {
  ManifestT manifest;

  absl::string_view path = absl::NullSafeStringView(kDisabledManifestPath);
  if (path.empty()) {
    return manifest;
  }

  // Note: parens are required to disambiguate vs function decl.
  std::ifstream file_stream((std::string(path)));
  std::string contents((std::istreambuf_iterator<char>(file_stream)),
                       std::istreambuf_iterator<char>());

  std::vector<std::string> lines = absl::StrSplit(contents, '\n');
  for (std::string& line : lines) {
    auto comment = line.find("//");
    if (comment != std::string::npos) {
      line = line.substr(0, comment);
    }
    if (line.empty()) {
      continue;
    }
    absl::StripTrailingAsciiWhitespace(&line);
    std::vector<std::string> pieces = absl::StrSplit(line, ' ');
    CHECK_GE(pieces.size(), 1);
    auto& platforms = manifest[pieces[0]];
    for (size_t i = 1; i < pieces.size(); ++i) {
      platforms.push_back(pieces[i]);
    }
  }
  return manifest;
}

} // namespace

void ManifestCheckingTest::SetUp() {
  const testing::TestInfo* test_info =
      testing::UnitTest::GetInstance()->current_test_info();
  absl::string_view test_case_name = test_info->test_suite_name();
  absl::string_view test_name = test_info->name();
  VLOG(1) << "test_case_name: " << test_case_name;
  VLOG(1) << "test_name: " << test_name;

  // Remove the type suffix from the test case name.
  if (const char* type_param = test_info->type_param()) {
    VLOG(1) << "type_param: " << type_param;
    size_t last_slash = test_case_name.rfind('/');
    test_case_name = test_case_name.substr(0, last_slash);
    VLOG(1) << "test_case_name: " << test_case_name;
  }

  // Remove the test instantiation name if it is present.
  auto first_slash = test_case_name.find('/');
  if (first_slash != test_case_name.npos) {
    test_case_name.remove_prefix(first_slash + 1);
    VLOG(1) << "test_case_name: " << test_case_name;
  }

  ManifestT manifest = ReadManifest();

  // If the test name ends with a slash followed by one or more characters,
  // strip that off.
  auto last_slash = test_name.rfind('/');
  if (last_slash != test_name.npos) {
    test_name = test_name.substr(0, last_slash);
    VLOG(1) << "test_name: " << test_name;
  }

  // First try full match: test_case_name.test_name
  // If that fails, try to find just the test_case_name; this would disable all
  // tests in the test case.
  auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
  if (it == manifest.end()) {
    it = manifest.find(test_case_name);
    if (it == manifest.end()) {
      return;
    }
  }

  // Expect a full match vs. one of the platform regexps to disable the test.
  const std::vector<std::string>& disabled_platforms = it->second;
  auto platform_string = kTestPlatform;
  for (const auto& s : disabled_platforms) {
    if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
      GTEST_SKIP();
      return;
    }
  }

  // We didn't hit in the disabled manifest entries, so don't disable it.
}

} // namespace xla
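For illustration only (not part of this change): a hypothetical manifest in the format ReadManifest() above parses; the test and platform names are invented. Each line names a test (or whole test case) followed by platform regexps, and // starts a comment.

// example_disabled_manifest.txt (comments are stripped before parsing)
MyConvertTest.ConvertsBf16OnGpu GPU
MyDotTest CPU Interpreter
MyParamTest.LargeParams Host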
@ -0,0 +1,35 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_

#include "tensorflow/core/platform/test.h"

namespace xla {

// This class allows us to intercept the test name and use an arbitrary
// heuristic to decide whether the test case should be disabled. We
// determine whether the test case should be disabled by resolving the (test
// case name, test name) in a manifest file.
class ManifestCheckingTest : public ::testing::Test {
 protected:
  // This method runs before each test runs.
  void SetUp() override;
};

} // namespace xla

#endif // TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_
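For illustration only (not part of this change): a hypothetical fixture showing how a suite picks up the manifest check simply by inheriting from ManifestCheckingTest; the class and test names are invented.

class ExampleXlaTest : public ManifestCheckingTest {
 protected:
  void SetUp() override {
    // Runs the manifest check first; a listed test is skipped via GTEST_SKIP().
    ManifestCheckingTest::SetUp();
  }
};

TEST_F(ExampleXlaTest, RunsUnlessDisabledInManifest) { EXPECT_TRUE(true); }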
@ -15,93 +15,18 @@ limitations under the License.

#include "tensorflow/compiler/xla/tests/test_macros.h"

#include <fstream>
#include <streambuf>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"

namespace xla {
namespace {

// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is
// disabled - a sequence of regexps.
using ManifestT = absl::flat_hash_map<string, std::vector<string>>;

ManifestT ReadManifest() {
  ManifestT manifest;

  string path = XLA_DISABLED_MANIFEST;
  if (path.empty()) {
    return manifest;
  }

  std::ifstream file_stream(path);
  // Note: parens are required to disambiguate vs function decl.
  string contents((std::istreambuf_iterator<char>(file_stream)),
                  std::istreambuf_iterator<char>());

  std::vector<string> lines = absl::StrSplit(contents, '\n');
  for (string& line : lines) {
    auto comment = line.find("//");
    if (comment != string::npos) {
      line = line.substr(0, comment);
    }
    if (line.empty()) {
      continue;
    }
    absl::StripTrailingAsciiWhitespace(&line);
    std::vector<string> pieces = absl::StrSplit(line, ' ');
    CHECK_GE(pieces.size(), 1);
    auto& platforms = manifest[pieces[0]];
    for (int64 i = 1; i < pieces.size(); ++i) {
      platforms.push_back(pieces[i]);
    }
  }
  return manifest;
static bool InitModule() {
  kDisabledManifestPath = XLA_DISABLED_MANIFEST;
  VLOG(1) << "kDisabledManifestPath: " << kDisabledManifestPath;
  kTestPlatform = XLA_PLATFORM;
  VLOG(1) << "kTestPlatform: " << kTestPlatform;
  return false;
}

} // namespace

std::string PrependDisabledIfIndicated(absl::string_view test_case_name,
                                       absl::string_view test_name) {
  ManifestT manifest = ReadManifest();

  // If the test name ends with a slash followed by one or more digits, strip
  // that off; this is just a shard number, and matching on this would be
  // unstable even if someone wanted to do it.
  static LazyRE2 shard_num_pattern = {R"(/\d+$)"};
  absl::string_view suffix;
  if (RE2::PartialMatch(test_name, *shard_num_pattern, &suffix)) {
    test_name.remove_suffix(suffix.size());
  }

  // First try full match: test_case_name.test_name
  // If that fails, try to find just the test_case_name; this would disable all
  // tests in the test case.
  auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name));
  if (it == manifest.end()) {
    it = manifest.find(test_case_name);
    if (it == manifest.end()) {
      return std::string(test_name);
    }
  }

  // Expect a full match vs. one of the platform regexps to disable the test.
  const std::vector<string>& disabled_platforms = it->second;
  string platform_string = XLA_PLATFORM;
  for (const auto& s : disabled_platforms) {
    if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) {
      return absl::StrCat("DISABLED_", test_name);
    }
  }

  // We didn't hit in the disabled manifest entries, so don't disable it.
  return std::string(test_name);
}
static bool module_initialized = InitModule();

} // namespace xla

@ -28,12 +28,6 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_

#include <string>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/test.h"

#define DISABLED_ON_CPU(X) X
#define DISABLED_ON_GPU(X) X
#define DISABLED_ON_GPU_ROCM(X) X
@ -79,117 +73,15 @@ limitations under the License.

namespace xla {

// Reads a disabled manifest file to resolve whether test cases should be
// disabled on a particular platform. For a test that should be disabled,
// returns DISABLED_ prepended to its name; otherwise returns the test name
// unmodified.
std::string PrependDisabledIfIndicated(absl::string_view test_case_name,
                                       absl::string_view test_name);
inline const char *kDisabledManifestPath = nullptr;
inline const char *kTestPlatform = nullptr;

} // namespace xla

// This is the internal "gtest" class instantiation -- it is identical to the
// GTEST_TEST_ macro, except that we intercept the test name for potential
// modification by PrependDisabledIfIndicated. That function can use an arbitrary
// heuristic to decide whether the test case should be disabled, and we
// determine whether the test case should be disabled by resolving the (test
// case name, test name) in a manifest file.
#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class)              \
  class GTEST_TEST_CLASS_NAME_(test_case_name, test_name)                     \
      : public parent_class {                                                  \
   public:                                                                     \
    GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {}                     \
                                                                                \
   private:                                                                     \
    virtual void TestBody();                                                    \
    static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_;       \
    GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name,      \
                                                           test_name));         \
  };                                                                             \
                                                                                 \
  ::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name,              \
                                                    test_name)::test_info_ =     \
      ::testing::RegisterTest(                                                    \
          #test_case_name,                                                        \
          ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name)          \
              .c_str(),                                                           \
          nullptr, nullptr, __FILE__, __LINE__, []() -> parent_class* {           \
            return new GTEST_TEST_CLASS_NAME_(test_case_name, test_name)();       \
          });                                                                      \
  void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()
#define XLA_TEST_F(test_fixture, test_name) TEST_F(test_fixture, test_name)

// This is identical to the TEST_F macro from "gtest", but it potentially
// disables the test based on an external manifest file, DISABLED_MANIFEST.
//
// Per usual, you can see what tests are available via --gunit_list_tests and
// choose to run tests that have been disabled via the manifest via
// --gunit_also_run_disabled_tests.
#define XLA_TEST_F(test_fixture, test_name) \
  XLA_GTEST_TEST_(test_fixture, test_name, test_fixture)
#define XLA_TEST_P(test_case_name, test_name) TEST_P(test_case_name, test_name)

// Likewise, this is identical to the TEST_P macro from "gtest", but
// potentially disables the test based on the DISABLED_MANIFEST file.
//
// We have to wrap this in an outer layer so that any DISABLED_ON_* macros will
// be properly expanded before the stringification occurs.
#define XLA_TEST_P_IMPL_(test_case_name, test_name)                            \
  class GTEST_TEST_CLASS_NAME_(test_case_name, test_name)                      \
      : public test_case_name {                                                \
   public:                                                                     \
    GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {}                     \
    virtual void TestBody();                                                   \
                                                                               \
   private:                                                                    \
    static int AddToRegistry() {                                               \
      ::testing::UnitTest::GetInstance()                                       \
          ->parameterized_test_registry()                                      \
          .GetTestCasePatternHolder<test_case_name>(                           \
              #test_case_name,                                                 \
              ::testing::internal::CodeLocation(__FILE__, __LINE__))           \
          ->AddTestPattern(                                                    \
              #test_case_name,                                                 \
              ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name)   \
                  .c_str(),                                                    \
              new ::testing::internal::TestMetaFactory<GTEST_TEST_CLASS_NAME_( \
                  test_case_name, test_name)>());                              \
      return 0;                                                                \
    }                                                                          \
    static int gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_;               \
    GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name,     \
                                                           test_name));        \
  };                                                                           \
  int GTEST_TEST_CLASS_NAME_(test_case_name,                                   \
                             test_name)::gtest_registering_dummy_ =            \
      GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::AddToRegistry();      \
  void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()

#define XLA_TEST_P(test_case_name, test_name) \
  XLA_TEST_P_IMPL_(test_case_name, test_name)

// This is identical to the TEST_F macro from "gtest", but it potentially
// disables the test based on an external manifest file, DISABLED_MANIFEST.
#define XLA_TYPED_TEST(CaseName, TestName)                                     \
  template <typename gtest_TypeParam_>                                         \
  class GTEST_TEST_CLASS_NAME_(CaseName, TestName)                             \
      : public CaseName<gtest_TypeParam_> {                                    \
   private:                                                                    \
    typedef CaseName<gtest_TypeParam_> TestFixture;                            \
    typedef gtest_TypeParam_ TypeParam;                                        \
    virtual void TestBody();                                                   \
  };                                                                           \
  bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ =   \
      ::testing::internal::TypeParameterizedTest<                              \
          CaseName,                                                            \
          ::testing::internal::TemplateSel<GTEST_TEST_CLASS_NAME_(CaseName,    \
                                                                  TestName)>,  \
          GTEST_TYPE_PARAMS_(CaseName)>::                                      \
          Register(                                                            \
              "", ::testing::internal::CodeLocation(__FILE__, __LINE__),       \
              #CaseName,                                                       \
              ::xla::PrependDisabledIfIndicated(#CaseName, #TestName).c_str(), \
              0);                                                              \
  template <typename gtest_TypeParam_>                                         \
  void GTEST_TEST_CLASS_NAME_(CaseName,                                        \
                              TestName)<gtest_TypeParam_>::TestBody()
#define XLA_TYPED_TEST(CaseName, TestName) TYPED_TEST(CaseName, TestName)

#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_

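For illustration only (not from this diff): how the macros above are typically used; the fixture and test names are invented.

// With the old scheme, XLA_TEST_F consulted the manifest at registration time
// and renamed a disabled test to DISABLED_<name>; after this change it expands
// to plain TEST_F and the skip decision moves to ManifestCheckingTest::SetUp().
XLA_TEST_F(ExampleXlaTest, DISABLED_ON_GPU_ROCM(RunsEverywhereButRocm)) {
  EXPECT_TRUE(true);
}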
@ -719,6 +719,7 @@ tf_cuda_library(
    visibility = [
        "//tensorflow/core:__pkg__",
        "//tensorflow/core/util:__pkg__",
        "//tensorflow/security/fuzzing:__subpackages__",
    ],
    deps = [
        ":allocation_description_proto_cc",

@ -153,16 +153,9 @@ limitations under the License.
#endif // defined(IS_MOBILE_PLATFORM) - end of TF_CALL_type defines

// Defines for sets of types.

// TODO(b/111604096): Add uint32 and uint64 to TF_CALL_INTEGRAL_TYPES.
//
// The uint32 and uint64 types were introduced in 10/2017 to be used via XLA and
// thus were not included in TF_CALL_INTEGRAL_TYPES. Including them in
// TF_CALL_INTEGRAL_TYPES should only happen after evaluating the effect on the
// TF binary size and performance.
#define TF_CALL_INTEGRAL_TYPES(m)                                      \
  TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \
      TF_CALL_uint8(m) TF_CALL_int8(m)
#define TF_CALL_INTEGRAL_TYPES(m)                                        \
  TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_int32(m) \
      TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)

#define TF_CALL_FLOAT_TYPES(m) \
  TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)

@ -174,10 +167,10 @@ limitations under the License.
#define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \
  TF_CALL_INTEGRAL_TYPES(m) TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)

#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)                              \
  TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)   \
      TF_CALL_int64(m) TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) \
      TF_CALL_int8(m)
#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)                                \
  TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)     \
      TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_uint16(m) \
      TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)

#define TF_CALL_COMPLEX_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m)

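For illustration only (not part of this diff): the usual way these TF_CALL_* macros are consumed, instantiating one kernel registration per type; the op and functor names are invented.

#define REGISTER_EXAMPLE_KERNEL(type)                                 \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("ExampleOp").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ExampleOp<CPUDevice, type>);

// After this change the integral set also instantiates uint32/uint64 kernels.
TF_CALL_INTEGRAL_TYPES(REGISTER_EXAMPLE_KERNEL);
#undef REGISTER_EXAMPLE_KERNEL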
@ -238,11 +238,6 @@ int DataTypeSize(DataType dt) {
    TF_CALL_qint16(CASE);
    TF_CALL_quint16(CASE);

    // uint32 and uint64 aren't included in TF_CALL_POD_TYPES because we
    // don't want to define kernels for them at this stage to avoid binary
    // bloat.
    TF_CALL_uint32(CASE);
    TF_CALL_uint64(CASE);
    default:
      return 0;
  }

@ -837,6 +837,7 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
          "ParameterizedTruncatedNormal", "TruncatedNormal", "RandomShuffle",
          "Multinomial", "RandomGamma", "RandomGammaGrad", "RandomPoisson",
          "RandomPoissonV2",
          // LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)

          // ReadVariableOp marked as stateful because it consumes DT_RESOURCE,
          // but it can't generate any observable side-effect.
@ -850,12 +851,7 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
          // the same device_ordinal on the same host.
          "EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
          "EnqueueTPUEmbeddingSparseTensorBatch",
          "EnqueueTPUEmbeddingRaggedTensorBatch",

          // SaveV2 and RestoreV2 should be allowed to operate in parallel on
          // multiple hosts.
          "SaveV2", "RestoreV2"});
  // LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
          "EnqueueTPUEmbeddingRaggedTensorBatch"});
  return exemption->contains(op);
}

@ -4168,6 +4168,25 @@ tf_kernel_library(
    ]),
)

tf_cuda_cc_test(
    name = "mlir_generated_op_gpu_tanh_test",
    size = "small",
    srcs = if_mlir_generated_gpu_kernels_enabled(["mlir_generated_op_gpu_tanh_test.cc"]),
    tags = tf_cuda_tests_tags() + ["no_rocm"],
    deps = [
        ":cwise_op",
        ":ops_testutil",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:tensorflow",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "//tensorflow/core/common_runtime:device",
        "//tensorflow/core/common_runtime:device_factory",
    ],
)

tf_kernel_library(
    name = "nextafter_op",
    prefix = "nextafter_op",
@@ -4900,7 +4919,9 @@ tf_kernel_library(
        "topk_op_gpu_double.cu.cc",
        "topk_op_gpu_float.cu.cc",
        "topk_op_gpu_half.cu.cc",
        "topk_op_gpu_uint64.cu.cc",
        "topk_op_gpu_int64.cu.cc",
        "topk_op_gpu_uint32.cu.cc",
        "topk_op_gpu_int32.cu.cc",
        "topk_op_gpu_int16.cu.cc",
        "topk_op_gpu_uint16.cu.cc",

@@ -6802,7 +6823,8 @@ filegroup(
        "cwise_op_minimum.cc",
        "cwise_op_mul_1.cc",
        "cwise_op_mul_2.cc",
        "cwise_op_neg.cc",
        "cwise_op_neg_1.cc",
        "cwise_op_neg_2.cc",
        "cwise_op_pow.cc",
        "cwise_op_real.cc",
        "cwise_op_reciprocal.cc",

@@ -8780,7 +8802,8 @@ exports_files([
    "cwise_op_mod.cc",
    "cwise_op_mul_1.cc",
    "cwise_op_mul_2.cc",
    "cwise_op_neg.cc",
    "cwise_op_neg_1.cc",
    "cwise_op_neg_2.cc",
    "cwise_op_not_equal_to_1.cc",
    "cwise_op_not_equal_to_2.cc",
    "cwise_op_round.cc",
@@ -116,8 +116,6 @@ REGISTER(qint8)
REGISTER(quint16)
REGISTER(qint16)
REGISTER(qint32)
REGISTER(uint32)
REGISTER(uint64)

#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \
    !defined(__ANDROID_TYPES_FULL__)

@@ -208,8 +208,6 @@ REGISTER_CONCAT(qint8);
REGISTER_CONCAT(quint16);
REGISTER_CONCAT(qint16);
REGISTER_CONCAT(qint32);
REGISTER_CONCAT(uint32);
REGISTER_CONCAT(uint64);

#undef REGISTER_CONCAT
@@ -211,7 +211,6 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL);
// the conversion from uint8 to quint8.
REGISTER_KERNEL(CPU, quint8);
REGISTER_KERNEL(CPU, quint16);
REGISTER_KERNEL(CPU, uint32);
#undef REGISTER_CPU_KERNEL

#ifdef TENSORFLOW_USE_SYCL
@@ -101,27 +101,21 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_REF_SWITCH);
REGISTER_CPU_SWITCH(uint64);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_SWITCH);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
REGISTER_GPU_SWITCH(uint64);
TF_CALL_variant(REGISTER_GPU_SWITCH);
TF_CALL_uint32(REGISTER_GPU_SWITCH);
TF_CALL_uint32(REGISTER_GPU_REF_SWITCH);
TF_CALL_bool(REGISTER_GPU_SWITCH);
TF_CALL_bool(REGISTER_GPU_REF_SWITCH);

#undef REGISTER_CPU_SWITCH
#undef REGISTER_CPU_REF_SWITCH
#undef REGISTER_GPU_SWITCH
#undef REGISTER_GPU_REF_SWITCH

// Special GPU kernels for int32, string & resource handles. Requiring all
// inputs and outputs to be in host memory.
// TODO(b/25387198): Also enable int32 in device memory.
// Special GPU kernels for int32 and string.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
#define REGISTER_GPU_HOST_KERNEL(type)       \
  REGISTER_KERNEL_BUILDER(Name("Switch")     \
                              .Device(DEVICE_GPU) \

@@ -151,6 +145,8 @@ TF_CALL_bool(REGISTER_GPU_REF_SWITCH);

REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_REF_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(bool);
REGISTER_GPU_HOST_REF_KERNEL(bool);
REGISTER_GPU_HOST_KERNEL(tstring);
REGISTER_GPU_HOST_REF_KERNEL(tstring);
REGISTER_GPU_HOST_KERNEL(ResourceHandle);
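The REGISTER_GPU_HOST_KERNEL macro is only partially visible in the hunk above. For orientation, the usual shape of such a host-memory registration for Switch is sketched below; the .HostMemory argument names are an assumption based on the Switch op's input/output names, not text copied from this diff.

    // Sketch of a host-memory kernel registration (argument names assumed).
    #define REGISTER_GPU_HOST_KERNEL(type)                    \
      REGISTER_KERNEL_BUILDER(Name("Switch")                  \
                                  .Device(DEVICE_GPU)         \
                                  .HostMemory("data")         \
                                  .HostMemory("pred")         \
                                  .HostMemory("output_false") \
                                  .HostMemory("output_true")  \
                                  .TypeConstraint<type>("T"), \
                              SwitchOp)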
@@ -311,7 +307,6 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool);
REGISTER_GPU_KERNEL(uint64);
TF_CALL_variant(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL
@@ -19,8 +19,8 @@ limitations under the License.

namespace tensorflow {
namespace functor {
DEFINE_UNARY7(neg, Eigen::half, float, double, int32, int64, complex64,
              complex128);
DEFINE_UNARY4(neg, int8, int16, int32, int64);
DEFINE_UNARY6(neg, Eigen::half, float, double, bfloat16, complex64, complex128);
}  // namespace functor
}  // namespace tensorflow

@@ -16,8 +16,7 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"

namespace tensorflow {
REGISTER8(UnaryOp, CPU, "Neg", functor::neg, float, Eigen::half, double, int32,
          complex64, int64, complex128, bfloat16);
REGISTER4(UnaryOp, CPU, "Neg", functor::neg, int8, int16, int32, int64);

#ifdef TENSORFLOW_USE_SYCL
REGISTER3(UnaryOp, SYCL, "Neg", functor::neg, float, double, int64);

@@ -30,8 +29,7 @@ REGISTER_KERNEL_BUILDER(Name("Neg")
#endif  // TENSORFLOW_USE_SYCL

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER6(UnaryOp, GPU, "Neg", functor::neg, float, Eigen::half, double, int64,
          complex64, complex128);
REGISTER3(UnaryOp, GPU, "Neg", functor::neg, int8, int16, int64);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
@@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/cwise_ops_common.h"

namespace tensorflow {
REGISTER6(UnaryOp, CPU, "Neg", functor::neg, Eigen::half, float, double,
          bfloat16, complex64, complex128);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER6(UnaryOp, GPU, "Neg", functor::neg, Eigen::half, float, double,
          bfloat16, complex64, complex128);
#endif
}  // namespace tensorflow
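For readers unfamiliar with the REGISTERn/DEFINE_UNARYn helpers used in these Neg hunks: each one expands into a per-type kernel registration, so splitting the type list across cwise_op_neg_1.cc and cwise_op_neg_2.cc only redistributes where those registrations are compiled. A rough sketch of what a single entry amounts to (the expansion shown is an approximation, not copied from the macro definitions):

    // Approximate expansion of the float entry of
    // REGISTER6(UnaryOp, CPU, "Neg", functor::neg, float, ...):
    REGISTER_KERNEL_BUILDER(
        Name("Neg").Device(DEVICE_CPU).TypeConstraint<float>("T"),
        UnaryOp<CPUDevice, functor::neg<float>>);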
@@ -220,8 +220,6 @@ Status DatasetOpsTestBase::ExpectEqual(const Tensor& a, const Tensor& b) {
      break;
    TF_CALL_NUMBER_TYPES(CASE);
    TF_CALL_tstring(CASE);
    TF_CALL_uint32(CASE);
    TF_CALL_uint64(CASE);
    // TODO(feihugis): figure out how to support variant tensors.
#undef CASE
    default:

@@ -98,7 +98,6 @@ typedef Eigen::SyclDevice SYCLDevice;

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
// uint32 not included in ALL_TYPES
TF_CALL_uint32(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
// quint16 not included in QUANTIZIED_TYPES
TF_CALL_quint16(REGISTER_KERNELS);

@@ -164,8 +164,6 @@ class DynamicPartitionOp : public DynamicPartitionOp_Shared {
                          DynamicPartitionOp<T>)

TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_PARTITION);
// For partitioning fingerprints.
TF_CALL_uint64(REGISTER_DYNAMIC_PARTITION);
#undef REGISTER_DYNAMIC_PARTITION

}  // namespace tensorflow
@@ -45,6 +45,8 @@ DEFINE_SETZERO_CPU(Eigen::half);
DEFINE_SETZERO_CPU(bfloat16);
DEFINE_SETZERO_CPU(float);
DEFINE_SETZERO_CPU(double);
DEFINE_SETZERO_CPU(uint32);
DEFINE_SETZERO_CPU(uint64);
DEFINE_SETZERO_CPU(uint8);
DEFINE_SETZERO_CPU(int8);
DEFINE_SETZERO_CPU(uint16);

@@ -96,6 +98,8 @@ DEFINE_SETONE_CPU(Eigen::half);
DEFINE_SETONE_CPU(bfloat16);
DEFINE_SETONE_CPU(float);
DEFINE_SETONE_CPU(double);
DEFINE_SETONE_CPU(uint32);
DEFINE_SETONE_CPU(uint64);
DEFINE_SETONE_CPU(uint8);
DEFINE_SETONE_CPU(int8);
DEFINE_SETONE_CPU(uint16);

@@ -137,7 +141,6 @@ struct FillFunctor<Eigen::ThreadPoolDevice, T> {
TF_CALL_ALL_TYPES(DEFINE_FILL_CPU);
DEFINE_FILL_CPU(quint8);
DEFINE_FILL_CPU(quint16);
DEFINE_FILL_CPU(uint32);
#undef DEFINE_FILL_CPU

#ifdef TENSORFLOW_USE_SYCL
@@ -211,8 +211,6 @@ TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
TF_CALL_quint16(REGISTER_GATHER_CPU);
TF_CALL_qint16(REGISTER_GATHER_CPU);
TF_CALL_uint32(REGISTER_GATHER_CPU);
TF_CALL_uint64(REGISTER_GATHER_CPU);

#undef REGISTER_GATHER_CPU

@@ -122,7 +122,6 @@ REGISTER_SYCL_HOST_KERNEL(bool);

TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
REGISTER_GPU_KERNEL(Variant);
TF_CALL_uint32(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL
@@ -178,6 +178,9 @@ class MklAddNOp : public OpKernel {
        dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_data_format);
      }

      std::shared_ptr<stream> fwd_cpu_stream;
      fwd_cpu_stream.reset(CreateStream(ctx, cpu_engine));

      // Create memory descriptor for MKL-DNN.
      // If all input in Tensorflow format, create block memory descriptor,
      // else convert TF format to MKL memory descriptor
@@ -215,6 +218,7 @@ class MklAddNOp : public OpKernel {
        srcs_pd.push_back(memory::primitive_desc(md, cpu_engine));
#endif
        src.SetUsrMem(md, &src_tensor);
        src.SetUsrMemDataHandle(&src_tensor, fwd_cpu_stream);
        inputs.push_back(src.GetOpMem());
      }

@@ -240,11 +244,10 @@ class MklAddNOp : public OpKernel {
      }
      AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
                                output_mkl_shape);
      dst.SetUsrMemDataHandle(dst_tensor);
      dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);

      // Create Sum op, and submit net for execution.
      std::vector<primitive> net;
      stream* fwd_cpu_stream = CreateStream(ctx, cpu_engine);
#ifdef ENABLE_MKLDNN_V1
      mkldnn::sum sum_op(sum_pd);
      std::unordered_map<int, memory> net_args = {
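The MKL hunks in this diff (AddN above, and Concat, Dequantize, LRN, and Transpose below) all make the same mechanical change: the execution stream is created once, up front, and passed into SetUsrMemDataHandle so the user-memory handles are bound on that stream. A condensed sketch of the resulting flow, reusing the names that appear in the AddN hunk (everything not shown there is elided, not invented):

    // Sketch only: create the stream first, then bind user memory handles to it.
    std::shared_ptr<stream> fwd_cpu_stream;
    fwd_cpu_stream.reset(CreateStream(ctx, cpu_engine));

    src.SetUsrMem(md, &src_tensor);
    src.SetUsrMemDataHandle(&src_tensor, fwd_cpu_stream);  // stream-aware binding
    dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);

    std::vector<primitive> net;  // the Sum primitive and its arguments go here
    // ... build sum_op / net_args and execute them on fwd_cpu_stream ...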
|
@ -281,11 +281,19 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
|||
std::shared_ptr<stream> fwd_stream) {
|
||||
DCHECK_EQ(in_data.size(), context_.data_mem.size());
|
||||
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
context_.data_mem_shdptr[i]->set_data_handle(
|
||||
static_cast<void*>(in_data[i].get_data_handle()), *fwd_stream);
|
||||
}
|
||||
context_.dst_mem->set_data_handle(
|
||||
static_cast<void*>(dst_data.get_data_handle()), *fwd_stream);
|
||||
#else
|
||||
context_.data_mem_shdptr[i]->set_data_handle(
|
||||
static_cast<void*>(in_data[i].get_data_handle()));
|
||||
}
|
||||
context_.dst_mem->set_data_handle(
|
||||
static_cast<void*>(dst_data.get_data_handle()));
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
|
||||
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
||||
context_.data_mem[i] = *context_.data_mem_shdptr[i];
|
||||
|
@ -788,11 +796,13 @@ class MklConcatOp : public OpKernel {
|
|||
dnn_shape_dst);
|
||||
DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";
|
||||
|
||||
std::shared_ptr<stream> fwd_cpu_stream;
|
||||
fwd_cpu_stream.reset(CreateStream(context, cpu_engine));
|
||||
|
||||
if (dnn_shape_dst.IsMklTensor())
|
||||
dst_md = dnn_shape_dst.GetMklLayout();
|
||||
dst.SetUsrMem(dst_md, dst_tensor);
|
||||
std::shared_ptr<stream> fwd_cpu_stream;
|
||||
fwd_cpu_stream.reset(CreateStream(context, cpu_engine));
|
||||
dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
auto concat_op = concat(concat_pd);
|
||||
std::unordered_map<int, memory> net_args = {
|
||||
|
@ -830,9 +840,10 @@ class MklConcatOp : public OpKernel {
|
|||
|
||||
dst_md = dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout()
|
||||
: dst_md;
|
||||
dst.SetUsrMem(dst_md, dst_tensor);
|
||||
std::shared_ptr<stream> fwd_cpu_stream;
|
||||
fwd_cpu_stream.reset(CreateStream(context, concat_fwd->GetEngine()));
|
||||
dst.SetUsrMem(dst_md, dst_tensor);
|
||||
dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
|
||||
// Execute concat
|
||||
concat_fwd->Execute(srcs_mem, dst.GetOpMem(), concat_fwd_dims,
|
||||
fwd_cpu_stream);
|
||||
|
|
|
@@ -75,6 +75,9 @@ class MklDequantizeOp : public OpKernel {
      MklDnnData<T> src(&cpu_engine);
      MklDnnData<float> dst(&cpu_engine);

      std::shared_ptr<stream> reorder_stream;
      reorder_stream.reset(CreateStream(ctx, cpu_engine));

      // If input is in MKL layout, then simply grab input layout; otherwise,
      // construct input TF layout. For TF layout, although input shape
      // (src_dims) required is in MKL-DNN order, the layout is Tensorflow's

@@ -85,6 +88,7 @@ class MklDequantizeOp : public OpKernel {
              : memory::desc(src_dims, MklDnnType<T>(), MEMORY_FORMAT::nhwc);

      src.SetUsrMem(src_md, &src_tensor);
      src.SetUsrMemDataHandle(&src_tensor, reorder_stream);

      Tensor* output_tensor = nullptr;
      MklDnnShape output_mkl_shape;

@@ -129,6 +133,7 @@ class MklDequantizeOp : public OpKernel {
      AllocateOutputSetMklShape(ctx, 0, &output_tensor, output_tf_shape,
                                output_mkl_shape);
      dst.SetUsrMem(dst_md, output_tensor);
      dst.SetUsrMemDataHandle(output_tensor, reorder_stream);

      // The quantization logic here for mode SCALED is similar to the logic
      // in QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3.

@@ -155,8 +160,6 @@ class MklDequantizeOp : public OpKernel {
      // Also it does not define round_nearest (enum).
      attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
#endif  // !ENABLE_MKLDNN_V1
      std::shared_ptr<stream> reorder_stream;
      reorder_stream.reset(CreateStream(ctx, cpu_engine));
      std::vector<primitive> net;

      // Create reorder primitive and then execute.
@@ -137,6 +137,7 @@ class MklLRNOp : public OpKernel {
      // that input is in NHWC layout with Channel being the last dimension.
      src_dnn_data.SetUsrMem(src_md, &src_tensor);
      src_dnn_data.SetOpMemDesc(input_dims, MEMORY_FORMAT::nhwc);
      src_dnn_data.SetUsrMemDataHandle(&src_tensor, fwd_stream_);

      // dst_dnn_data has the same shape as input.
      dst_dnn_data.SetUsrMem(src_md);

@@ -157,7 +158,7 @@ class MklLRNOp : public OpKernel {
                                &output_tensor);
      OP_REQUIRES_OK(context, context->status());
      DCHECK(output_tensor != nullptr);
      dst_dnn_data.SetUsrMemDataHandle(output_tensor);
      dst_dnn_data.SetUsrMemDataHandle(output_tensor, fwd_stream_);

      // Handle workspace required for MKL-DNN.
      AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);

@@ -393,6 +394,7 @@ class MklLRNGradOp : public OpKernel {
          orig_input_dnn_shape.GetSizesAsMklDnnDims();
      orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
      orig_input_dnn_data.SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc);
      orig_input_dnn_data.SetUsrMemDataHandle(&orig_input_tensor, bwd_stream_);

      // output_dnn_data has the same shape as original input
      output_dnn_data.SetUsrMem(orig_input_md);

@@ -421,7 +423,7 @@ class MklLRNGradOp : public OpKernel {
                                orig_input_format, &output_tensor);
      OP_REQUIRES_OK(context, context->status());
      DCHECK(output_tensor != nullptr);
      output_dnn_data.SetUsrMemDataHandle(output_tensor);
      output_dnn_data.SetUsrMemDataHandle(output_tensor, bwd_stream_);

      // Create LRN primitive and add it to the net
      // At this point, workspace is enabled, so we don't need
@@ -137,6 +137,7 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
    memory::dims out_strides =
        ReorderStrides(CalculateTFStrides(out_dims), perm);

    std::shared_ptr<stream> transpose_stream;
    in.SetUsrMem(in_dims, in_strides, &in_tensor);
    // Output dimensions are same as input dimensions. We adjust the layout
    // using strides.

@@ -144,16 +145,16 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,

    std::vector<primitive> net;
#ifdef ENABLE_MKLDNN_V1
    std::shared_ptr<stream> transpose_stream;
    auto* prim = FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem());
    transpose_stream.reset(CreateStream(context, prim->GetEngine()));
    in.SetUsrMemDataHandle(&in_tensor, transpose_stream);
    out.SetUsrMemDataHandle(out_tensor, transpose_stream);
    net.push_back(*(prim->GetPrimitive()));
    std::vector<MemoryArgsMap> net_args;
    net_args.push_back({{MKLDNN_ARG_FROM, *in.GetUsrMem()},
                        {MKLDNN_ARG_TO, *out.GetUsrMem()}});
    execute_primitives(net, transpose_stream, net_args);
#else
    std::shared_ptr<stream> transpose_stream;
    transpose_stream.reset(new CPU_STREAM(cpu_engine));
    net.push_back(FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem()));
    transpose_stream->submit(net).wait();
@@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

class MlirGeneratedOpGpuTanhTest : public OpsTestBase {
 protected:
  void SetUp() override {
    std::unique_ptr<tensorflow::Device> device_gpu(
        tensorflow::DeviceFactory::NewDevice("GPU", {},
                                             "/job:a/replica:0/task:0"));
    SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
  }
  template <typename T, typename RT = T>
  void RunTanhOp(std::initializer_list<T> input) {
    TensorShape shape({2, 7});
    TF_ASSERT_OK(NodeDefBuilder("tanh_op", "Tanh")
                     .Input(FakeInput(DataTypeToEnum<T>::v()))
                     .Attr("T", DataTypeToEnum<T>::v())
                     .Finalize(node_def()));

    TF_ASSERT_OK(InitOp());
    AddInputFromArray<T>(shape, input);
    TF_ASSERT_OK(RunOpKernel());

    Tensor expected_tensor(allocator(), DataTypeToEnum<T>::value, shape);
    std::vector<T> expected;
    expected.reserve(input.size());
    for (const T& inp : input) {
      expected.push_back(static_cast<T>(std::tanh(static_cast<RT>(inp))));
    }
    test::FillValues<T>(&expected_tensor, expected);
    test::ExpectClose(expected_tensor, *GetOutput(0));
  }
};

TEST_F(MlirGeneratedOpGpuTanhTest, TanhFloat) {
  RunTanhOp<float>({-18.0f, -9.0f, -1e-6f, -0.0f, 0.0f, 1e-6, 0.1f, 0.2f, 0.3f,
                    0.5f, 0.7f, 0.9f, 9.0f, 18.0f});
}

TEST_F(MlirGeneratedOpGpuTanhTest, TanhDouble) {
  RunTanhOp<double>({-18.0, -9.0, -1e-6, -0.0, 0.0, 1e-6, 0.1, 0.2, 0.3, 0.5,
                     0.7, 0.9, 9.0, 18.0});
}

TEST_F(MlirGeneratedOpGpuTanhTest, TanhHalf) {
  RunTanhOp<Eigen::half, float>(
      {static_cast<Eigen::half>(-18.0), static_cast<Eigen::half>(-9.0),
       static_cast<Eigen::half>(-1e-6), static_cast<Eigen::half>(-0.0),
       static_cast<Eigen::half>(0.0), static_cast<Eigen::half>(1e-6),
       static_cast<Eigen::half>(0.1), static_cast<Eigen::half>(0.2),
       static_cast<Eigen::half>(0.3), static_cast<Eigen::half>(0.5),
       static_cast<Eigen::half>(0.7), static_cast<Eigen::half>(0.9),
       static_cast<Eigen::half>(9.0), static_cast<Eigen::half>(18.0)});
}

}  // namespace
}  // end namespace tensorflow
@@ -296,8 +296,6 @@ TF_CALL_tstring(REGISTER_CPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
TF_CALL_quint16(REGISTER_CPU_KERNEL);
TF_CALL_qint16(REGISTER_CPU_KERNEL);
TF_CALL_uint32(REGISTER_CPU_KERNEL);
TF_CALL_uint64(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
#undef REGISTER_CPU_KERNEL_WITH_INDEX_TYPE

@@ -308,8 +308,6 @@ TF_CALL_tstring(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_qint16(REGISTER_KERNELS);
TF_CALL_uint32(REGISTER_KERNELS);
TF_CALL_uint64(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#undef REGISTER_KERNELS_WITH_SPLIT_TYPE
}  // namespace tensorflow

@@ -561,8 +561,6 @@ TF_CALL_string(REGISTER_CPU_KERNEL);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_KERNEL);
TF_CALL_quint16(REGISTER_CPU_KERNEL);
TF_CALL_qint16(REGISTER_CPU_KERNEL);
TF_CALL_uint32(REGISTER_CPU_KERNEL);
TF_CALL_uint64(REGISTER_CPU_KERNEL);

#undef REGISTER_CPU_KERNEL

@@ -213,8 +213,6 @@ TF_CALL_tstring(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_quint16(REGISTER_KERNELS);
TF_CALL_qint16(REGISTER_KERNELS);
TF_CALL_uint32(REGISTER_KERNELS);
TF_CALL_uint64(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#undef REGISTER_KERNELS_WITH_SPLIT_TYPE
}  // namespace tensorflow
@@ -35,6 +35,7 @@ namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

static constexpr int VectorSizeElements = 8;
namespace functor {

// This kernel computes ReluGrad by processing one half2, two fp16, at a time.

@@ -93,6 +94,66 @@ __global__ void ReluGradHalfKernel(const Eigen::half* __restrict__ gradient,
  }
}

__global__ void ReluGradHalfKernelVector(
    const Eigen::half* __restrict__ gradient,
    const Eigen::half* __restrict__ feature, Eigen::half* __restrict__ backprop,
    int32 count) {
  int32 half8_count = count / VectorSizeElements;
  int32 index = blockIdx.x * blockDim.x + threadIdx.x;

  if (index < half8_count) {
    // Cast to xx_h8 for vector load and store.
    float4 gradient_h8 = reinterpret_cast<const float4*>(gradient)[index];
    float4 feature_h8 = reinterpret_cast<const float4*>(feature)[index];
    float4* p_backprop_h8 = reinterpret_cast<float4*>(backprop) + index;

    half2* gradient_h2 = reinterpret_cast<half2*>(&gradient_h8);
    half2* feature_h2 = reinterpret_cast<half2*>(&feature_h8);
    float4 backprop_h8;
    half2* p_backprop_h2 = reinterpret_cast<half2*>(&backprop_h8);

    // Fast path, when half2 primitives are available.
#if __CUDA_ARCH__ >= 530
    const half2 kZeroH2 = __float2half2_rn(0.f);
#endif
    for (int i = 0; i < VectorSizeElements / 2; i++) {
#if __CUDA_ARCH__ >= 530
      // mask = (feature > 0)
      half2 mask_h2 = __hgt2(feature_h2[i], kZeroH2);
      // backprop = mask * gradient
      half2 backprop_h2 = __hmul2(mask_h2, gradient_h2[i]);
#else
      // Fall back: convert half2 to float2 for processing.
      float2 feature_f2 = __half22float2(feature_h2[i]);
      float2 gradient_f2 = __half22float2(gradient_h2[i]);
      float2 backprop_f2 =
          make_float2((feature_f2.x > 0.0f) ? float(gradient_f2.x) : 0.0f,
                      (feature_f2.y > 0.0f) ? float(gradient_f2.y) : 0.0f);
      // Convert back to half2.
      half2 backprop_h2 = __float22half2_rn(backprop_f2);
#endif
      p_backprop_h2[i] = backprop_h2;
    }
    // Write back the result.
    *p_backprop_h8 = backprop_h8;
  }

  int remaining_count = (count % VectorSizeElements);

  if (index < remaining_count) {
    // Use first threads to process the remaining elements.
    Eigen::half grad_h = gradient[half8_count * VectorSizeElements + index];
    Eigen::half feature_h = feature[half8_count * VectorSizeElements + index];

    float grad_f = static_cast<float>(grad_h);
    float feature_f = static_cast<float>(feature_h);
    float backprop_f = (feature_f > 0) ? grad_f : 0;

    Eigen::half backprop_h(backprop_f);
    backprop[half8_count * VectorSizeElements + index] = backprop_h;
  }
}

template <typename Device>
struct ReluGrad<Device, Eigen::half> {
  // Computes ReluGrad backprop.

@@ -108,15 +169,28 @@ struct ReluGrad<Device, Eigen::half> {
  // NOTE: When the activation is exactly zero, we do not propagate the
  // associated gradient value. This allows the output of the Relu to be used,
  // as well as its input.
    auto gradient_ptr = reinterpret_cast<uintptr_t>(gradient.data());
    auto feature_ptr = reinterpret_cast<uintptr_t>(feature.data());
    auto backprop_ptr = reinterpret_cast<uintptr_t>(backprop.data());
    bool aligned = gradient_ptr % 16 == 0 && feature_ptr % 16 == 0 &&
                   backprop_ptr % 16 == 0;
    int32 count = gradient.size();
    if (count == 0) return;
    int32 half2_count = Eigen::divup(count, 2);
    constexpr int32 kThreadInBlock = 512;
    GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
        half2_count, d, ReluGradHalfKernel, 0, kThreadInBlock);
    TF_CHECK_OK(GpuLaunchKernel(
        ReluGradHalfKernel, config.block_count, config.thread_per_block, 0,
        d.stream(), gradient.data(), feature.data(), backprop.data(), count));
    if (count == 0) return;
    if (aligned) {
      int32 half8_count = Eigen::divup(count, VectorSizeElements);
      int32 kBlock = Eigen::divup(half8_count, kThreadInBlock);
      TF_CHECK_OK(GpuLaunchKernel(
          ReluGradHalfKernelVector, kBlock, kThreadInBlock, 0, d.stream(),
          gradient.data(), feature.data(), backprop.data(), count));
    } else {
      int32 half2_count = Eigen::divup(count, 2);
      GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize(
          half2_count, d, ReluGradHalfKernel, 0, kThreadInBlock);
      TF_CHECK_OK(GpuLaunchKernel(
          ReluGradHalfKernel, config.block_count, config.thread_per_block, 0,
          d.stream(), gradient.data(), feature.data(), backprop.data(), count));
    }
  }
};
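A short editorial note on the constants in the vectorized ReluGrad path above: VectorSizeElements is 8 because 8 Eigen::half values occupy 8 x 2 = 16 bytes, exactly one float4, so each thread on the vector path moves one float4 per tensor; the three % 16 checks exist because such vector loads must be 16-byte aligned. A self-contained sketch of that alignment test (helper name hypothetical):

    #include <cstdint>
    // 8 halves == 16 bytes == one float4, so the vector kernel is only safe when
    // all three buffers are 16-byte aligned.
    inline bool CanUseHalf8Path(const void* gradient, const void* feature,
                                const void* backprop) {
      auto aligned16 = [](const void* p) {
        return reinterpret_cast<std::uintptr_t>(p) % 16 == 0;
      };
      return aligned16(gradient) && aligned16(feature) && aligned16(backprop);
    }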
@@ -512,7 +512,6 @@ class AssignVariableOp<Device, Variant> : public OpKernel {

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
TF_CALL_uint32(REGISTER_KERNELS);
#undef REGISTER_KERNELS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

@@ -43,7 +43,6 @@ void Split<Eigen::ThreadPoolDevice, T, NDims>::operator()(

TF_CALL_ALL_TYPES(DEFINE_CPU_KERNELS)
DEFINE_CPU_KERNELS(quint8)
DEFINE_CPU_KERNELS(uint64)

#ifdef TENSORFLOW_USE_SYCL
template <typename T, int NDims>

@@ -404,7 +404,6 @@ class SplitOpSYCL : public SplitOpBase<SYCLDevice, T> {

TF_CALL_ALL_TYPES(REGISTER_SPLIT);
REGISTER_SPLIT(quint8);
REGISTER_SPLIT(uint64);

#undef REGISTER_SPLIT
@@ -440,8 +440,6 @@ class StridedSliceAssignOp : public OpKernel {
                          StridedSliceAssignOp<CPUDevice, type, true>)

TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE);
TF_CALL_uint32(REGISTER_STRIDED_SLICE);
TF_CALL_uint64(REGISTER_STRIDED_SLICE);

#undef REGISTER_STRIDED_SLICE

@@ -287,8 +287,6 @@ TF_CALL_GPU_ALL_TYPES(DECLARE_FOR_N_GPU);
#endif  // END GOOGLE_CUDA || TENSORFLOW_USE_ROCM

TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);
TF_CALL_uint32(DECLARE_FOR_N_CPU);
TF_CALL_uint64(DECLARE_FOR_N_CPU);

#ifdef TENSORFLOW_USE_SYCL
#define PREVENT_FOR_N_SYCL(T) \
@@ -52,7 +52,8 @@ class SummaryScalarOp : public OpKernel {
    Summary s;
    for (int i = 0; i < Ttags.size(); i++) {
      Summary::Value* v = s.add_value();
      v->set_tag(string(Ttags(i)));  // NOLINT
      const tstring& Ttags_i = Ttags(i);
      v->set_tag(Ttags_i.data(), Ttags_i.size());
      v->set_simple_value(float(Tvalues(i)));
    }

@@ -102,7 +103,8 @@ class SummaryHistoOp : public OpKernel {

    Summary s;
    Summary::Value* v = s.add_value();
    v->set_tag(string(tags.scalar<tstring>()()));  // NOLINT
    const tstring& tags0 = tags.scalar<tstring>()();
    v->set_tag(tags0.data(), tags0.size());
    histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);

    Tensor* summary_tensor = nullptr;
@@ -258,7 +258,6 @@ namespace functor {

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_INTEGRAL_TYPES(DECLARE_GPU_SPEC);
TF_CALL_uint32(DECLARE_GPU_SPEC);

#undef DECLARE_GPU_SPEC

@@ -276,7 +275,6 @@ TF_CALL_uint32(DECLARE_GPU_SPEC);

TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
TF_CALL_uint32(REGISTER_KERNELS)
#undef REGISTER_KERNELS

#endif  // end GOOGLE_CUDA
@@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/topk_op.h"
#include "tensorflow/core/kernels/topk_op_gpu.h"

namespace tensorflow {
using Eigen::GpuDevice;

template struct functor::TopKFunctor<GPUDevice, uint32>;
}  // namespace tensorflow

#endif  // GOOGLE_CUDA
@@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/topk_op.h"
#include "tensorflow/core/kernels/topk_op_gpu.h"

namespace tensorflow {
using Eigen::GpuDevice;

template struct functor::TopKFunctor<GPUDevice, uint64>;
}  // namespace tensorflow

#endif  // GOOGLE_CUDA
@@ -252,7 +252,8 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL);

TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_uint32(REGISTER_GPU_KERNELS);
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_COMPLEX_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -78,3 +78,32 @@ op {
    }
  }
}
op {
  name: "Acos"
  input_arg {
    name: "x"
    type_attr: "T"
  }
  output_arg {
    name: "y"
    type_attr: "T"
  }
  attr {
    name: "T"
    type: "type"
    allowed_values {
      list {
        type: DT_BFLOAT16
        type: DT_HALF
        type: DT_FLOAT
        type: DT_DOUBLE
        type: DT_INT8
        type: DT_INT16
        type: DT_INT32
        type: DT_INT64
        type: DT_COMPLEX64
        type: DT_COMPLEX128
      }
    }
  }
}
@@ -78,3 +78,32 @@ op {
    }
  }
}
op {
  name: "Asin"
  input_arg {
    name: "x"
    type_attr: "T"
  }
  output_arg {
    name: "y"
    type_attr: "T"
  }
  attr {
    name: "T"
    type: "type"
    allowed_values {
      list {
        type: DT_BFLOAT16
        type: DT_HALF
        type: DT_FLOAT
        type: DT_DOUBLE
        type: DT_INT8
        type: DT_INT16
        type: DT_INT32
        type: DT_INT64
        type: DT_COMPLEX64
        type: DT_COMPLEX128
      }
    }
  }
}
@@ -78,3 +78,32 @@ op {
    }
  }
}
op {
  name: "Atan"
  input_arg {
    name: "x"
    type_attr: "T"
  }
  output_arg {
    name: "y"
    type_attr: "T"
  }
  attr {
    name: "T"
    type: "type"
    allowed_values {
      list {
        type: DT_BFLOAT16
        type: DT_HALF
        type: DT_FLOAT
        type: DT_DOUBLE
        type: DT_INT8
        type: DT_INT16
        type: DT_INT32
        type: DT_INT64
        type: DT_COMPLEX64
        type: DT_COMPLEX128
      }
    }
  }
}
@@ -168,3 +168,32 @@ op {
    }
  }
}
op {
  name: "Inv"
  input_arg {
    name: "x"
    type_attr: "T"
  }
  output_arg {
    name: "y"
    type_attr: "T"
  }
  attr {
    name: "T"
    type: "type"
    allowed_values {
      list {
        type: DT_BFLOAT16
        type: DT_HALF
        type: DT_FLOAT
        type: DT_DOUBLE
        type: DT_INT8
        type: DT_INT16
        type: DT_INT32
        type: DT_INT64
        type: DT_COMPLEX64
        type: DT_COMPLEX128
      }
    }
  }
}