Build scatter HLO op with MlirHloBuilder
Here, the XlaComputation is imported as a region for the scatter HLO op in MLIR. This is similar to how XlaBuilder constructs instructions by copying the given computation.

Another approach for MlirHloBuilder could be to build the MLIR region directly, avoiding the extra conversion from XlaComputation to MLIR. This is theoretically possible, but it would require changes to construct a new sub-XlaBuilder from the existing one instead of from scratch, so that the same MLIRContext can be used for the new builder and ops can be moved between the two freely. Even then, we may still need to construct an HloModuleProto in some cases to support XlaComputation methods like proto() and Snapshot() that return protos. These changes would complicate the conversions, so it is better to keep the fallback path simpler and accept some overhead; performance should still be similar to using XlaBuilder.

Whitelisted the SparseToDense op for testing.

PiperOrigin-RevId: 316207774
Change-Id: I7016f150c9d1aa514ff9ede1f69baf545b6da6aa
parent d09085241a
commit 1fbc648c9c
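The fallback path described above amounts to round-tripping the update computation through its proto form: serialize the XlaComputation to an HloModuleProto, rebuild an HloModule from it, and import the entry computation as an MLIR region. A condensed sketch of that flow (the name ImportAsRegionSketch is invented for illustration; it mirrors the ImportComputation helper added in this change and assumes the usual XLA headers and the TF_ASSIGN_OR_RETURN macro):

    // Convert an XlaComputation into an MLIR region by rebuilding an
    // HloModule from its proto and importing the entry computation.
    xla::Status ImportAsRegionSketch(const xla::XlaComputation& computation,
                                     mlir::Region* region,
                                     mlir::OpBuilder* builder) {
      const xla::HloModuleProto& proto = computation.proto();
      TF_ASSIGN_OR_RETURN(auto config,
                          xla::HloModule::CreateModuleConfigFromProto(
                              proto, xla::DebugOptions()));
      TF_ASSIGN_OR_RETURN(auto hlo_module,
                          xla::HloModule::CreateFromProto(proto, config));
      // Import the entry computation's body into the destination region.
      return xla::HloFunctionImporter::ImportAsRegion(
          *hlo_module->entry_computation(), region, builder);
    }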
tensorflow/compiler/mlir/xla/BUILD

@@ -660,6 +660,7 @@ cc_library(
     deps = [
         ":attribute_importer",
         ":hlo",
+        ":hlo_module_importer",
         ":hlo_utils",
         ":type_to_shape",
         "//tensorflow/compiler/xla:comparison_util",
tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc

@@ -19,10 +19,12 @@ limitations under the License.
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
+#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
 #include "tensorflow/compiler/xla/comparison_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/shape_inference.h"
 #include "tensorflow/compiler/xla/util.h"
 
@@ -140,6 +142,24 @@ StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
   return MakeXlaOp(op);
 }
 
+StatusOr<XlaOp> MlirHloBuilder::ScatterInternal(
+    const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
+    const XlaComputation& update_computation,
+    const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
+    bool unique_indices) {
+  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
+                                         shape, builder_));
+  auto op = builder_.create<mlir::xla_hlo::ScatterOp>(
+      loc_, ty, GetValue(input), GetValue(scatter_indices), GetValue(updates),
+      ConvertScatterDimensionNumbers(dimension_numbers, &builder_),
+      builder_.getBoolAttr(indices_are_sorted),
+      builder_.getBoolAttr(unique_indices));
+
+  TF_RETURN_IF_ERROR(
+      ImportComputation(update_computation.proto(), &op.update_computation()));
+  return MakeXlaOp(op);
+}
+
 StatusOr<XlaOp> MlirHloBuilder::RngOpInternal(
     RandomDistribution distribution, absl::Span<const XlaOp> parameters,
     const Shape& shape) {
@@ -348,6 +368,18 @@ StatusOr<XlaOp> MlirHloBuilder::CreateOp(
   return MakeXlaOp(op->getResult(0));
 }
 
+Status MlirHloBuilder::ImportComputation(const HloModuleProto& computation,
+                                         mlir::Region* region) {
+  TF_ASSIGN_OR_RETURN(auto module_config,
+                      xla::HloModule::CreateModuleConfigFromProto(
+                          computation, xla::DebugOptions()));
+  TF_ASSIGN_OR_RETURN(auto hlo_module, xla::HloModule::CreateFromProto(
+                                           computation, module_config));
+
+  return HloFunctionImporter::ImportAsRegion(*hlo_module->entry_computation(),
+                                             region, &builder_);
+}
+
 StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
   TF_RETURN_IF_ERROR(first_error());
   TF_RETURN_IF_ERROR(CheckOpBuilder(op));
tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h

@@ -129,6 +129,12 @@ class MlirHloBuilder : public XlaBuilder {
       const GatherDimensionNumbers& dimension_numbers,
       absl::Span<const int64> slice_sizes, bool indices_are_sorted) override;
 
+  StatusOr<XlaOp> ScatterInternal(
+      const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
+      const XlaComputation& update_computation,
+      const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
+      bool unique_indices) override;
+
   StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution,
                                 absl::Span<const XlaOp> parameters,
                                 const Shape& shape) override;
@@ -196,6 +202,9 @@ class MlirHloBuilder : public XlaBuilder {
                           llvm::ArrayRef<XlaOp> operands,
                           llvm::ArrayRef<mlir::NamedAttribute> attributes = {});
 
+  Status ImportComputation(const HloModuleProto& computation,
+                           mlir::Region* region);
+
   mlir::OpBuilder builder_;
   mlir::Location loc_;
 
tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir

@@ -187,6 +187,33 @@ func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2
   return %0: tensor<3x4xi32>
 }
 
+// CHECK-LABEL: @sparse_to_dense
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xi32>, %[[ARG1:.*]]: tensor<3xf32>, %[[ARG2:.*]]: tensor<f32>)
+func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>) -> tensor<3x3xf32> {
+
+// CHECK: %[[CST:.*]] = xla_hlo.constant dense<3> : tensor<2xi32>
+// CHECK: %[[DEFAULT:.*]] = "xla_hlo.broadcast_in_dim"(%[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<3x3xf32>
+
+// CHECK: %[[RESULT:.*]] = "xla_hlo.scatter"(%[[DEFAULT]], %[[ARG0]], %[[ARG1]]) ( {
+// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>): // no predecessors
+// CHECK: "xla_hlo.return"(%[[ARG4]]) : (tensor<f32>) -> ()
+// CHECK: })
+// CHECK-SAME: indices_are_sorted = false
+// CHECK-SAME: scatter_dimension_numbers
+// CHECK-SAME: index_vector_dim = 1 : i64
+// CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
+// CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
+// CHECK-SAME: update_window_dims = dense<[]> : tensor<0xi64>
+// CHECK-SAME: unique_indices = false
+// CHECK-SAME: (tensor<3x3xf32>, tensor<3x2xi32>, tensor<3xf32>) -> tensor<3x3xf32>
+
+// return %[[RESULT]] : tensor<3x3xf32>
+
+  %cst = xla_hlo.constant dense<3> : tensor<2xi32>
+  %0 = "tf.SparseToDense"(%arg0, %cst, %arg1, %arg2) {validate_indices = true}: (tensor<3x2xi32>, tensor<2xi32>, tensor<3xf32>, tensor<f32>) -> tensor<3x3xf32>
+  return %0 : tensor<3x3xf32>
+}
+
 // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
 // available but doesn't support this instance.
 }
tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc

@@ -156,6 +156,7 @@ static bool IsOpWhitelisted(Operation* op) {
     TypeID::get<TF::SoftplusGradOp>(),
     TypeID::get<TF::SoftsignGradOp>(),
     TypeID::get<TF::SoftsignOp>(),
+    TypeID::get<TF::SparseToDenseOp>(),
     TypeID::get<TF::SqrtGradOp>(),
     TypeID::get<TF::SquareOp>(),
     TypeID::get<TF::SubOp>(),
tensorflow/compiler/tests/BUILD

@@ -1218,6 +1218,7 @@ tf_xla_py_test(
     name = "sparse_to_dense_op_test",
     size = "small",
     srcs = ["sparse_to_dense_op_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
tensorflow/compiler/tests/sparse_to_dense_op_test.py

@@ -22,6 +22,7 @@ import numpy as np
 
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.platform import test
@@ -101,6 +102,7 @@ class SparseToDenseTest(xla_test.XLATestCase):
       with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"):
         _SparseToDense([1, 3], [[5], [3]], 1, -1)
 
+  @test_util.disable_mlir_bridge("Error handling")
   def testBadValue(self):
     with self.session(), self.test_scope():
       with self.assertRaisesOpError(
@@ -108,12 +110,14 @@ class SparseToDenseTest(xla_test.XLATestCase):
           r"should be \[\] or \[2\]"):
         _SparseToDense([1, 3], [5], [[5], [3]], -1)
 
+  @test_util.disable_mlir_bridge("Error handling")
   def testBadNumValues(self):
     with self.session(), self.test_scope():
       with self.assertRaisesOpError(
           r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"):
         _SparseToDense([1, 3], [5], [1, 2, 3], -1)
 
+  @test_util.disable_mlir_bridge("Error handling")
   def testBadDefault(self):
     with self.session(), self.test_scope():
       with self.assertRaisesOpError("default_value should be a scalar"):
tensorflow/compiler/xla/client/xla_builder.cc

@@ -1899,11 +1899,6 @@ XlaOp XlaBuilder::Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
                           const ScatterDimensionNumbers& dimension_numbers,
                           bool indices_are_sorted, bool unique_indices) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
-    instr.set_indices_are_sorted(indices_are_sorted);
-
-    instr.set_unique_indices(unique_indices);
-
     TF_ASSIGN_OR_RETURN(const Shape* input_shape, GetShapePtr(input));
     TF_ASSIGN_OR_RETURN(const Shape* scatter_indices_shape,
                         GetShapePtr(scatter_indices));
@@ -1914,8 +1909,22 @@ XlaOp XlaBuilder::Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
         Shape shape, ShapeInference::InferScatterShape(
                          *input_shape, *scatter_indices_shape, *updates_shape,
                          to_apply_shape, dimension_numbers));
-    *instr.mutable_shape() = shape.ToProto();
+    return ScatterInternal(shape, input, scatter_indices, updates,
+                           update_computation, dimension_numbers,
+                           indices_are_sorted, unique_indices);
+  });
+}
+
+StatusOr<XlaOp> XlaBuilder::ScatterInternal(
+    const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
+    const XlaComputation& update_computation,
+    const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
+    bool unique_indices) {
+  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+    instr.set_indices_are_sorted(indices_are_sorted);
+    instr.set_unique_indices(unique_indices);
+    *instr.mutable_shape() = shape.ToProto();
     *instr.mutable_scatter_dimension_numbers() = dimension_numbers;
 
     AddCalledComputation(update_computation, &instr);
tensorflow/compiler/xla/client/xla_builder.h

@@ -653,6 +653,12 @@ class XlaBuilder {
                 const ScatterDimensionNumbers& dimension_numbers,
                 bool indices_are_sorted = false, bool unique_indices = false);
 
+  virtual StatusOr<XlaOp> ScatterInternal(
+      const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
+      const XlaComputation& update_computation,
+      const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
+      bool unique_indices);
+
   void Send(XlaOp operand, const ChannelHandle& handle);
   XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);
 
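One structural note on the xla_builder change above: XlaBuilder::Scatter now performs shape inference once and then delegates to the virtual ScatterInternal hook, which MlirHloBuilder overrides to emit an MLIR scatter op instead of an HloInstructionProto. A self-contained toy version of that template-method layering (class and member names here are invented for illustration; the real signatures are in the diff above):

    #include <iostream>
    #include <string>

    // Toy stand-ins for xla::Shape and xla::XlaOp.
    struct Shape { std::string repr; };
    struct Op { std::string text; };

    class Builder {
     public:
      // Public entry point: shape inference happens exactly once here; the
      // virtual hook then decides how the op is materialized.
      Op Scatter(const Op& input) {
        Shape inferred{"f32[3,3]"};  // stands in for ShapeInference::InferScatterShape
        return ScatterInternal(inferred, input);
      }

     protected:
      // Default lowering, analogous to XlaBuilder's proto-building path.
      virtual Op ScatterInternal(const Shape& shape, const Op& input) {
        return Op{"scatter proto(" + input.text + ") : " + shape.repr};
      }
    };

    class MlirBuilder : public Builder {
     protected:
      // MlirHloBuilder-style override: reuses the inferred shape, emits MLIR.
      Op ScatterInternal(const Shape& shape, const Op& input) override {
        return Op{"xla_hlo.scatter(" + input.text + ") : " + shape.repr};
      }
    };

    int main() {
      MlirBuilder b;
      std::cout << b.Scatter(Op{"%arg0"}).text << "\n";
      // prints: xla_hlo.scatter(%arg0) : f32[3,3]
    }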