Introduce Xla{Dot,Conv}V2

These support mixed operand precision (bf16 @ f32 -> f32, int8 @ int8 -> int32) and thus diverge significantly from the original Xla{Dot,Conv}.

PiperOrigin-RevId: 358860920
Change-Id: I316f1b43268b268ce50528f7aa307182bce304f6
parent 85c597aaa2
commit 4856f23a49
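For context, the int8 path this commit adds can be exercised from Python roughly as follows. This is a minimal sketch assuming a TensorFlow build that includes this commit and an XLA-compiled context; the import paths match the test file changed in this diff.

    import numpy as np
    from tensorflow.compiler.tf2xla.python import xla
    from tensorflow.compiler.xla import xla_data_pb2

    # Contract dim 1 of lhs with dim 0 of rhs (a plain matmul), but request an
    # int32 result from int8 operands via the new preferred_element_type.
    dnums = xla_data_pb2.DotDimensionNumbers()
    dnums.lhs_contracting_dimensions.append(1)
    dnums.rhs_contracting_dimensions.append(0)

    lhs = np.array([[1, 2], [3, 4]], dtype=np.int8)
    rhs = np.array([[5, 6], [7, 8]], dtype=np.int8)
    out = xla.dot_general(lhs, rhs, dimension_numbers=dnums,
                          preferred_element_type=np.int32)  # int32 output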
@@ -694,8 +694,10 @@ tensorflow::MemoryTypeVector GetOutputMemoryTypes(
   static auto const ops_triggering_xla_compilation =
       new absl::flat_hash_set<std::string>{"XlaBroadcastHelper",
                                            "XlaConv",
+                                           "XlaConvV2",
                                            "XlaDequantize",
                                            "XlaDot",
+                                           "XlaDotV2",
                                            "XlaDynamicSlice",
                                            "XlaDynamicUpdateSlice",
                                            "XlaEinsum",
@@ -2057,8 +2057,10 @@ absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
           "While",
           "XlaBroadcastHelper",
           "XlaConv",
+          "XlaConvV2",
           "XlaDequantize",
           "XlaDot",
+          "XlaDotV2",
           "XlaDynamicSlice",
           "XlaDynamicUpdateSlice",
           "XlaEinsum",
@@ -29,6 +29,14 @@ namespace tensorflow {
                               .HostMemory("feature_group_count")              \
                               .Device(DEVICE),                                \
                           XlaCompileOnDemandOp);                              \
+  REGISTER_KERNEL_BUILDER(Name("XlaConvV2")                                   \
+                              .HostMemory("window_strides")                   \
+                              .HostMemory("padding")                          \
+                              .HostMemory("lhs_dilation")                     \
+                              .HostMemory("rhs_dilation")                     \
+                              .HostMemory("feature_group_count")              \
+                              .Device(DEVICE),                                \
+                          XlaCompileOnDemandOp);                              \
   REGISTER_KERNEL_BUILDER(                                                    \
       Name("XlaBroadcastHelper").HostMemory("broadcast_dims").Device(DEVICE), \
       XlaCompileOnDemandOp);                                                  \
@@ -38,6 +46,8 @@ namespace tensorflow {
                           XlaCompileOnDemandOp);                              \
   REGISTER_KERNEL_BUILDER(Name("XlaDot").Device(DEVICE),                      \
                           XlaCompileOnDemandOp);                              \
+  REGISTER_KERNEL_BUILDER(Name("XlaDotV2").Device(DEVICE),                    \
+                          XlaCompileOnDemandOp);                              \
   REGISTER_KERNEL_BUILDER(                                                    \
       Name("XlaDynamicSlice").HostMemory("size_indices").Device(DEVICE),      \
       XlaCompileOnDemandOp);                                                  \
@@ -17682,6 +17682,37 @@ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_XlaConvV2Op : TF_Op<"XlaConvV2", [NoSideEffect]> {
+  let summary = "Wraps the XLA ConvGeneralDilated operator, documented at";
+
+  let description = [{
+https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
+.
+  }];
+
+  let arguments = (ins
+    Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{the input tensor}]>:$lhs,
+    Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{the kernel tensor}]>:$rhs,
+    Arg<TF_I32OrI64Tensor, [{the inter-window strides}]>:$window_strides,
+    Arg<TF_I32OrI64Tensor, [{the padding to apply at the start and end of each input dimensions}]>:$padding,
+    Arg<TF_I32OrI64Tensor, [{dilation to apply between input elements}]>:$lhs_dilation,
+    Arg<TF_I32OrI64Tensor, [{dilation to apply between kernel elements}]>:$rhs_dilation,
+    Arg<TF_I32OrI64Tensor, [{number of feature groups for grouped convolution.}]>:$feature_group_count,
+
+    StrAttr:$dimension_numbers,
+    StrAttr:$precision_config
+  );
+
+  let results = (outs
+    TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
+  );
+
+  TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<2>;
+  TF_DerivedOperandTypeAttr LhsT = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr RhsT = TF_DerivedOperandTypeAttr<1>;
+  TF_DerivedResultTypeAttr preferred_element_type = TF_DerivedResultTypeAttr<0>;
+}
+
 def TF_XlaDotOp : TF_Op<"XlaDot", [NoSideEffect]> {
   let summary = "Wraps the XLA DotGeneral operator, documented at";
 
@@ -17705,6 +17736,31 @@ https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_XlaDotV2Op : TF_Op<"XlaDotV2", [NoSideEffect]> {
+  let summary = "Wraps the XLA DotGeneral operator, documented at";
+
+  let description = [{
+https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
+.
+  }];
+
+  let arguments = (ins
+    Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{the LHS tensor}]>:$lhs,
+    Arg<TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>, [{the RHS tensor}]>:$rhs,
+
+    StrAttr:$dimension_numbers,
+    StrAttr:$precision_config
+  );
+
+  let results = (outs
+    TensorOf<[TF_Bfloat16, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
+  );
+
+  TF_DerivedOperandTypeAttr LhsT = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr RhsT = TF_DerivedOperandTypeAttr<1>;
+  TF_DerivedResultTypeAttr preferred_element_type = TF_DerivedResultTypeAttr<0>;
+}
+
 def TF_XlaDynamicSliceOp : TF_Op<"XlaDynamicSlice", [NoSideEffect]> {
   let summary = "Wraps the XLA DynamicSlice operator, documented at";
 
@@ -264,7 +264,9 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
     TypeID::get<TF::UpperBoundOp>(),
     TypeID::get<TF::XlaBroadcastHelperOp>(),
     TypeID::get<TF::XlaConvOp>(),
+    TypeID::get<TF::XlaConvV2Op>(),
     TypeID::get<TF::XlaDotOp>(),
+    TypeID::get<TF::XlaDotV2Op>(),
    TypeID::get<TF::XlaDynamicSliceOp>(),
    TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
    TypeID::get<TF::XlaEinsumOp>(),
@@ -205,6 +205,35 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
         ],
         dtype=dtype))
 
+  def testDotGeneralInt8xInt8ToInt32(self):
+
+    def dot_fn(lhs, rhs):
+      dnums = xla_data_pb2.DotDimensionNumbers()
+      dnums.lhs_contracting_dimensions.append(2)
+      dnums.rhs_contracting_dimensions.append(1)
+      dnums.lhs_batch_dimensions.append(0)
+      dnums.rhs_batch_dimensions.append(0)
+      return xla.dot_general(
+          lhs, rhs, dimension_numbers=dnums, preferred_element_type=np.int32)
+
+    lhs = np.array([
+        [[1, 2], [3, 4]],
+        [[5, 6], [7, 8]],
+    ], dtype=np.int8)
+    rhs = np.array([
+        [[1, 2, 3], [4, 5, 6]],
+        [[7, 8, 9], [10, 11, 12]],
+    ],
+                   dtype=np.int8)
+    self._assertOpOutputMatchesExpected(
+        dot_fn,
+        args=(lhs, rhs),
+        expected=np.array([
+            [[9, 12, 15], [19, 26, 33]],
+            [[95, 106, 117], [129, 144, 159]],
+        ],
+                          dtype=np.int32))
+
   def testNeg(self):
     for dtype in self.numeric_types - {np.uint8, np.int8}:
       self._assertOpOutputMatchesExpected(
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -38,6 +39,7 @@ class XlaConvOp : public XlaOpKernel {
     OP_REQUIRES(context,
                 precision_config_.ParsePartialFromString(precision_config_attr),
                 errors::InvalidArgument("Error parsing precision config."));
+    preferred_element_type_ = absl::nullopt;
   }
 
   void Compile(XlaOpKernelContext* context) override {
@@ -77,10 +79,13 @@ class XlaConvOp : public XlaOpKernel {
     xla::XlaOp output = xla::ConvGeneralDilated(
         context->Input(0), context->Input(1), window_strides, padding,
         lhs_dilation, rhs_dilation, dnums_, feature_group_count,
-        /*batch_group_count=*/1, &precision_config_);
+        /*batch_group_count=*/1, &precision_config_, preferred_element_type_);
     context->SetOutput(0, output);
   }
 
+ protected:
+  absl::optional<xla::PrimitiveType> preferred_element_type_;
+
  private:
   xla::ConvolutionDimensionNumbers dnums_;
   xla::PrecisionConfig precision_config_;
@@ -96,5 +101,29 @@ REGISTER_XLA_OP(Name("XlaConv")
                     .CompileTimeConstantInput("padding"),
                 XlaConvOp);
 
+class XlaConvV2Op : public XlaConvOp {
+ public:
+  explicit XlaConvV2Op(OpKernelConstruction* context) : XlaConvOp(context) {
+    DataType preferred_element_dtype;
+    OP_REQUIRES_OK(context, context->GetAttr("preferred_element_type",
+                                             &preferred_element_dtype));
+    xla::PrimitiveType preferred_element_type;
+    OP_REQUIRES_OK(context, DataTypeToPrimitiveType(preferred_element_dtype,
+                                                    &preferred_element_type));
+    preferred_element_type_ = preferred_element_type;
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaConvV2Op);
+};
+
+REGISTER_XLA_OP(Name("XlaConvV2")
+                    .CompileTimeConstantInput("window_strides")
+                    .CompileTimeConstantInput("lhs_dilation")
+                    .CompileTimeConstantInput("rhs_dilation")
+                    .CompileTimeConstantInput("feature_group_count")
+                    .CompileTimeConstantInput("padding"),
+                XlaConvV2Op);
+
 }  // namespace
 }  // namespace tensorflow
@@ -14,12 +14,14 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.pb.h"
 
 namespace tensorflow {
 namespace {
@@ -39,6 +41,7 @@ class XlaDotOp : public XlaOpKernel {
         context,
         precision_config_.ParsePartialFromString(precision_config_attr),
         errors::InvalidArgument("Error parsing convolution dimension numbers"));
+    preferred_element_type_ = absl::nullopt;
   }
 
   void Compile(XlaOpKernelContext* context) override {
@@ -47,19 +50,40 @@ class XlaDotOp : public XlaOpKernel {
 
     // We do only minimal checking, relying on XLA to check the shape
     // invariants.
-    xla::XlaOp output = xla::DotGeneral(context->Input(0), context->Input(1),
-                                        dnums_, &precision_config_);
+    xla::XlaOp output =
+        xla::DotGeneral(context->Input(0), context->Input(1), dnums_,
+                        &precision_config_, preferred_element_type_);
     context->SetOutput(0, output);
   }
 
+ protected:
+  absl::optional<xla::PrimitiveType> preferred_element_type_;
+
  private:
   xla::DotDimensionNumbers dnums_;
   xla::PrecisionConfig precision_config_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp);
 };
 
 REGISTER_XLA_OP(Name("XlaDot"), XlaDotOp);
 
+class XlaDotV2Op : public XlaDotOp {
+ public:
+  explicit XlaDotV2Op(OpKernelConstruction* context) : XlaDotOp(context) {
+    DataType preferred_element_dtype;
+    OP_REQUIRES_OK(context, context->GetAttr("preferred_element_type",
+                                             &preferred_element_dtype));
+    xla::PrimitiveType preferred_element_type;
+    OP_REQUIRES_OK(context, DataTypeToPrimitiveType(preferred_element_dtype,
+                                                    &preferred_element_type));
+    preferred_element_type_ = preferred_element_type;
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaDotV2Op);
+};
+
+REGISTER_XLA_OP(Name("XlaDotV2"), XlaDotV2Op);
+
 }  // namespace
 }  // namespace tensorflow
@@ -161,6 +161,147 @@ dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
 precision_config: a serialized xla::PrecisionConfig proto.
 )doc");
 
+REGISTER_OP("XlaConvV2")
+    .Input("lhs: LhsT")
+    .Input("rhs: RhsT")
+    .Input("window_strides: Tindices")
+    .Input("padding: Tindices")
+    .Input("lhs_dilation: Tindices")
+    .Input("rhs_dilation: Tindices")
+    .Input("feature_group_count: Tindices")
+    .Attr("LhsT: numbertype")
+    .Attr("RhsT: numbertype")
+    .Attr("Tindices: {int32, int64}")
+    .Attr("dimension_numbers: string")
+    .Attr("precision_config: string")
+    .Attr("preferred_element_type: numbertype")
+    .Output("output: preferred_element_type")
+    .SetShapeFn(UnchangedRank)
+    .Doc(R"doc(
+Wraps the XLA ConvGeneralDilated operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
+.
+
+lhs: the input tensor
+rhs: the kernel tensor
+window_strides: the inter-window strides
+padding: the padding to apply at the start and end of each input dimensions
+lhs_dilation: dilation to apply between input elements
+rhs_dilation: dilation to apply between kernel elements
+feature_group_count: number of feature groups for grouped convolution.
+dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto.
+precision_config: a serialized xla::PrecisionConfig proto.
+preferred_element_type: The type of the tensor.
+)doc");
+
+static Status XlaDotShapeFunction(shape_inference::InferenceContext* c) {
+  shape_inference::ShapeHandle lhs_shape_handle = c->input(0);
+  shape_inference::ShapeHandle rhs_shape_handle = c->input(1);
+  if (!c->FullyDefined(lhs_shape_handle) ||
+      !c->FullyDefined(rhs_shape_handle)) {
+    return shape_inference::UnknownShape(c);
+  }
+
+  string dimension_numbers_string;
+  TF_RETURN_IF_ERROR(
+      c->GetAttr("dimension_numbers", &dimension_numbers_string));
+
+  xla::DotDimensionNumbers dimension_numbers;
+  dimension_numbers.ParseFromString(dimension_numbers_string);
+
+  // Check that number of contracting dimensions match.
+  if (dimension_numbers.lhs_contracting_dimensions_size() !=
+      dimension_numbers.rhs_contracting_dimensions_size())
+    return errors::InvalidArgument(
+        "Must specify the same number of contracting dimensions for lhs "
+        "and rhs. Got: ",
+        dimension_numbers.lhs_contracting_dimensions_size(), " and ",
+        dimension_numbers.rhs_contracting_dimensions_size());
+
+  // Check that contracting dimension sizes match.
+  for (int64 i = 0; i < dimension_numbers.lhs_contracting_dimensions_size();
+       ++i) {
+    const int64 lhs_contracting_dimension =
+        dimension_numbers.lhs_contracting_dimensions(i);
+    const int64 rhs_contracting_dimension =
+        dimension_numbers.rhs_contracting_dimensions(i);
+    shape_inference::DimensionOrConstant lhs_contracting_dimension_or_constant(
+        c->DimKnownRank(lhs_shape_handle, lhs_contracting_dimension));
+    shape_inference::DimensionOrConstant rhs_contracting_dimension_or_constant(
+        c->DimKnownRank(rhs_shape_handle, rhs_contracting_dimension));
+    const int64 lhs_contracting_dimension_size =
+        c->Value(lhs_contracting_dimension_or_constant);
+    const int64 rhs_contracting_dimension_size =
+        c->Value(rhs_contracting_dimension_or_constant);
+    if (lhs_contracting_dimension_size != rhs_contracting_dimension_size) {
+      return errors::InvalidArgument(
+          "Contracting dimension sizes do not match. Got: ",
+          lhs_contracting_dimension_size, " and ",
+          rhs_contracting_dimension_size);
+    }
+  }
+
+  // Check that number of batch dimensions match.
+  if (dimension_numbers.lhs_batch_dimensions_size() !=
+      dimension_numbers.rhs_batch_dimensions_size())
+    return errors::InvalidArgument(
+        "Must specify the same number of batch dimensions for lhs "
+        "and rhs. Got: ",
+        dimension_numbers.lhs_batch_dimensions_size(), " and ",
+        dimension_numbers.rhs_batch_dimensions_size());
+
+  // Check that batch dimension sizes match.
+  for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
+    const int64 lhs_batch_dimension = dimension_numbers.lhs_batch_dimensions(i);
+    const int64 rhs_batch_dimension = dimension_numbers.rhs_batch_dimensions(i);
+    shape_inference::DimensionOrConstant lhs_batch_dimension_or_constant(
+        c->DimKnownRank(lhs_shape_handle, lhs_batch_dimension));
+    shape_inference::DimensionOrConstant rhs_batch_dimension_or_constant(
+        c->DimKnownRank(rhs_shape_handle, rhs_batch_dimension));
+    const int64 lhs_batch_dimension_size =
+        c->Value(lhs_batch_dimension_or_constant);
+    const int64 rhs_batch_dimension_size =
+        c->Value(rhs_batch_dimension_or_constant);
+    if (lhs_batch_dimension_size != rhs_batch_dimension_size) {
+      return errors::InvalidArgument(
+          "Batch dimension sizes do not match. Got: ", lhs_batch_dimension_size,
+          " and ", rhs_batch_dimension_size);
+    }
+  }
+
+  // The ranks of lhs and rhs are decremented by 1 respectively due to the
+  // contraction, and added for the rank of the result. When an input tensor
+  // is a scalar, its contribution to the rank of the result is 0. Generate
+  // the result dimensions in order: batch dimensions, then the lhs dimensions
+  // followed by the rhs dimensions, excluding the contracted and batch
+  // dimensions.
+  std::vector<shape_inference::DimensionHandle> output_dims;
+  for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
+    output_dims.emplace_back(c->Dim(lhs_shape_handle, lhs_dim));
+  }
+  const int32 lhs_rank = c->Rank(lhs_shape_handle);
+  for (int64 i = 0; i < lhs_rank; ++i) {
+    if (absl::c_linear_search(dimension_numbers.lhs_contracting_dimensions(),
+                              i) ||
+        absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(), i)) {
+      continue;
+    }
+    output_dims.emplace_back(c->Dim(lhs_shape_handle, i));
+  }
+
+  const int32 rhs_rank = c->Rank(rhs_shape_handle);
+  for (int64 i = 0; i < rhs_rank; ++i) {
+    if (absl::c_linear_search(dimension_numbers.rhs_contracting_dimensions(),
+                              i) ||
+        absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(), i)) {
+      continue;
+    }
+    output_dims.emplace_back(c->Dim(rhs_shape_handle, i));
+  }
+
+  c->set_output(0, c->MakeShape(output_dims));
+  return Status::OK();
+}
+
 REGISTER_OP("XlaDot")
     .Input("lhs: T")
     .Input("rhs: T")
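As a sanity check on the shape rule implemented above (batch dimensions first, then the surviving lhs dimensions, then the surviving rhs dimensions), the int8 test added in this commit contracts an lhs of shape [2, 2, 2] with an rhs of shape [2, 2, 3] over lhs dim 2 / rhs dim 1, with batch dim 0 on both sides, giving an output of shape [2, 2, 3]. A hedged numpy sketch of the same contraction:

    # Sanity-check the dot_general output shape rule with plain numpy.
    # output dims = lhs batch dims + remaining lhs dims + remaining rhs dims.
    import numpy as np
    lhs = np.zeros((2, 2, 2), np.int8)   # batch dim 0, contracting dim 2
    rhs = np.zeros((2, 2, 3), np.int8)   # batch dim 0, contracting dim 1
    out = np.einsum('bmk,bkn->bmn', lhs.astype(np.int32), rhs.astype(np.int32))
    assert out.shape == (2, 2, 3)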
@@ -168,120 +309,7 @@ REGISTER_OP("XlaDot")
     .Attr("dimension_numbers: string")
     .Attr("precision_config: string")
     .Output("output: T")
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
-      shape_inference::ShapeHandle lhs_shape_handle = c->input(0);
-      shape_inference::ShapeHandle rhs_shape_handle = c->input(1);
-      if (!c->FullyDefined(lhs_shape_handle) ||
-          !c->FullyDefined(rhs_shape_handle)) {
-        return shape_inference::UnknownShape(c);
-      }
-
-      string dimension_numbers_string;
-      TF_RETURN_IF_ERROR(
-          c->GetAttr("dimension_numbers", &dimension_numbers_string));
-
-      xla::DotDimensionNumbers dimension_numbers;
-      dimension_numbers.ParseFromString(dimension_numbers_string);
-
-      // Check that number of contracting dimensions match.
-      if (dimension_numbers.lhs_contracting_dimensions_size() !=
-          dimension_numbers.rhs_contracting_dimensions_size())
-        return errors::InvalidArgument(
-            "Must specify the same number of contracting dimensions for lhs "
-            "and rhs. Got: ",
-            dimension_numbers.lhs_contracting_dimensions_size(), " and ",
-            dimension_numbers.rhs_contracting_dimensions_size());
-
-      // Check that contracting dimension sizes match.
-      for (int64 i = 0;
-           i < dimension_numbers.lhs_contracting_dimensions_size(); ++i) {
-        const int64 lhs_contracting_dimension =
-            dimension_numbers.lhs_contracting_dimensions(i);
-        const int64 rhs_contracting_dimension =
-            dimension_numbers.rhs_contracting_dimensions(i);
-        shape_inference::DimensionOrConstant
-            lhs_contracting_dimension_or_constant(
-                c->DimKnownRank(lhs_shape_handle, lhs_contracting_dimension));
-        shape_inference::DimensionOrConstant
-            rhs_contracting_dimension_or_constant(
-                c->DimKnownRank(rhs_shape_handle, rhs_contracting_dimension));
-        const int64 lhs_contracting_dimension_size =
-            c->Value(lhs_contracting_dimension_or_constant);
-        const int64 rhs_contracting_dimension_size =
-            c->Value(rhs_contracting_dimension_or_constant);
-        if (lhs_contracting_dimension_size != rhs_contracting_dimension_size) {
-          return errors::InvalidArgument(
-              "Contracting dimension sizes do not match. Got: ",
-              lhs_contracting_dimension_size, " and ",
-              rhs_contracting_dimension_size);
-        }
-      }
-
-      // Check that number of batch dimensions match.
-      if (dimension_numbers.lhs_batch_dimensions_size() !=
-          dimension_numbers.rhs_batch_dimensions_size())
-        return errors::InvalidArgument(
-            "Must specify the same number of batch dimensions for lhs "
-            "and rhs. Got: ",
-            dimension_numbers.lhs_batch_dimensions_size(), " and ",
-            dimension_numbers.rhs_batch_dimensions_size());
-
-      // Check that batch dimension sizes match.
-      for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size();
-           ++i) {
-        const int64 lhs_batch_dimension =
-            dimension_numbers.lhs_batch_dimensions(i);
-        const int64 rhs_batch_dimension =
-            dimension_numbers.rhs_batch_dimensions(i);
-        shape_inference::DimensionOrConstant lhs_batch_dimension_or_constant(
-            c->DimKnownRank(lhs_shape_handle, lhs_batch_dimension));
-        shape_inference::DimensionOrConstant rhs_batch_dimension_or_constant(
-            c->DimKnownRank(rhs_shape_handle, rhs_batch_dimension));
-        const int64 lhs_batch_dimension_size =
-            c->Value(lhs_batch_dimension_or_constant);
-        const int64 rhs_batch_dimension_size =
-            c->Value(rhs_batch_dimension_or_constant);
-        if (lhs_batch_dimension_size != rhs_batch_dimension_size) {
-          return errors::InvalidArgument(
-              "Batch dimension sizes do not match. Got: ",
-              lhs_batch_dimension_size, " and ", rhs_batch_dimension_size);
-        }
-      }
-
-      // The ranks of lhs and rhs are decremented by 1 respectively due to the
-      // contraction, and added for the rank of the result. When an input
-      // tensor is a scalar, its contribution to the rank of the result is 0.
-      // Generate the result dimensions in order, rhs dimensions followed by
-      // lhs dimensions except the contracted and batch dimensions.
-      std::vector<shape_inference::DimensionHandle> output_dims;
-      for (int64 lhs_dim : dimension_numbers.lhs_batch_dimensions()) {
-        output_dims.emplace_back(c->Dim(lhs_shape_handle, lhs_dim));
-      }
-      const int32 lhs_rank = c->Rank(lhs_shape_handle);
-      for (int64 i = 0; i < lhs_rank; ++i) {
-        if (absl::c_linear_search(
-                dimension_numbers.lhs_contracting_dimensions(), i) ||
-            absl::c_linear_search(dimension_numbers.lhs_batch_dimensions(),
-                                  i)) {
-          continue;
-        }
-        output_dims.emplace_back(c->Dim(lhs_shape_handle, i));
-      }
-
-      const int32 rhs_rank = c->Rank(rhs_shape_handle);
-      for (int64 i = 0; i < rhs_rank; ++i) {
-        if (absl::c_linear_search(
-                dimension_numbers.rhs_contracting_dimensions(), i) ||
-            absl::c_linear_search(dimension_numbers.rhs_batch_dimensions(),
-                                  i)) {
-          continue;
-        }
-        output_dims.emplace_back(c->Dim(rhs_shape_handle, i));
-      }
-
-      c->set_output(0, c->MakeShape(output_dims));
-      return Status::OK();
-    })
+    .SetShapeFn(XlaDotShapeFunction)
     .Doc(R"doc(
 Wraps the XLA DotGeneral operator, documented at
 https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
@@ -293,6 +321,28 @@ dimension_numbers: a serialized xla::DotDimensionNumbers proto.
 precision_config: a serialized xla::PrecisionConfig proto.
 )doc");
 
+REGISTER_OP("XlaDotV2")
+    .Input("lhs: LhsT")
+    .Input("rhs: RhsT")
+    .Attr("LhsT: numbertype")
+    .Attr("RhsT: numbertype")
+    .Attr("dimension_numbers: string")
+    .Attr("precision_config: string")
+    .Attr("preferred_element_type: numbertype")
+    .Output("output: preferred_element_type")
+    .SetShapeFn(XlaDotShapeFunction)
+    .Doc(R"doc(
+Wraps the XLA DotGeneral operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
+.
+
+lhs: the LHS tensor
+rhs: the RHS tensor
+dimension_numbers: a serialized xla::DotDimensionNumbers proto.
+precision_config: a serialized xla::PrecisionConfig proto.
+preferred_element_type: The type of the tensor.
+)doc");
+
 REGISTER_OP("XlaSetBound")
     .Input("input: int32")
     .Input("bound: int32")
@@ -39,6 +39,7 @@ from tensorflow.python.ops import gen_random_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import special_math_ops
+from tensorflow.python.ops.numpy_ops import np_utils
 
 # TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
 # ops include:
@@ -249,6 +250,7 @@ def conv(lhs,
          dimension_numbers,
          feature_group_count=1,
          precision_config=None,
+         preferred_element_type=None,
          name=None):
   """Wraps the XLA ConvGeneralDilated operator.
 
@@ -266,6 +268,7 @@ def conv(lhs,
     dimension_numbers: a `ConvolutionDimensionNumbers` proto.
     feature_group_count: number of feature groups for grouped convolution.
     precision_config: a `xla.PrecisionConfig` proto.
+    preferred_element_type: the result `dtype`.
     name: an optional name for the operator
 
   Returns:
@@ -274,7 +277,9 @@ def conv(lhs,
   precision_config_proto = ""
   if precision_config:
     precision_config_proto = precision_config.SerializeToString()
-  return gen_xla_ops.xla_conv(
+  if preferred_element_type is None:
+    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
+  return gen_xla_ops.xla_conv_v2(
       lhs,
       rhs,
       window_strides=window_strides,
@@ -284,6 +289,7 @@ def conv(lhs,
       feature_group_count=feature_group_count,
       dimension_numbers=dimension_numbers.SerializeToString(),
       precision_config=precision_config_proto,
+      preferred_element_type=preferred_element_type,
       name=name)
 
 
@@ -294,15 +300,23 @@ def dot(lhs, rhs, name=None):
   return math_ops.tensordot(lhs, rhs, axes=1, name=name)
 
 
-def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
+def dot_general(lhs,
+                rhs,
+                dimension_numbers,
+                precision_config=None,
+                preferred_element_type=None,
+                name=None):
   precision_config_proto = ""
   if precision_config:
     precision_config_proto = precision_config.SerializeToString()
-  return gen_xla_ops.xla_dot(
+  if preferred_element_type is None:
+    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
+  return gen_xla_ops.xla_dot_v2(
       lhs,
       rhs,
       dimension_numbers=dimension_numbers.SerializeToString(),
       precision_config=precision_config_proto,
+      preferred_element_type=preferred_element_type,
       name=name)
 
 
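Note the behavioral wrinkle in the wrapper changes above: `xla.conv` and `xla.dot_general` now always lower to the V2 ops, and when `preferred_element_type` is omitted they default it to `np_utils.result_type(lhs.dtype, rhs.dtype)`. That default is what enables the bf16 @ f32 -> f32 case from the commit message. A hedged sketch of what that looks like for a 1-D convolution; the NCW/OIW dimension-number plumbing here is an illustrative assumption, not taken from this diff:

    import tensorflow as tf
    from tensorflow.compiler.tf2xla.python import xla
    from tensorflow.compiler.xla import xla_data_pb2

    # 1-D conv, NCW input / OIW kernel, with mixed operand dtypes.
    lhs = tf.ones([1, 1, 8], dtype=tf.bfloat16)  # input
    rhs = tf.ones([1, 1, 3], dtype=tf.float32)   # kernel

    dnums = xla_data_pb2.ConvolutionDimensionNumbers()
    dnums.input_batch_dimension = 0
    dnums.input_feature_dimension = 1
    dnums.input_spatial_dimensions.append(2)
    dnums.kernel_output_feature_dimension = 0
    dnums.kernel_input_feature_dimension = 1
    dnums.kernel_spatial_dimensions.append(2)
    dnums.output_batch_dimension = 0
    dnums.output_feature_dimension = 1
    dnums.output_spatial_dimensions.append(2)

    # preferred_element_type is omitted, so the wrapper defaults it to
    # np_utils.result_type(bfloat16, float32), i.e. float32.
    out = xla.conv(lhs, rhs, window_strides=[1], padding=[[0, 0]],
                   lhs_dilation=[1], rhs_dilation=[1],
                   dimension_numbers=dnums)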
@@ -143,7 +143,8 @@ bool IsOpTensorCoreEligible(absl::string_view tf_op_name) {
       || absl::StrContains(tf_op_name, "CudnnRNNForward")
       || absl::StrContains(tf_op_name, "CudnnRNNBackprop")
       // Special cases.
-      || absl::EndsWith(tf_op_name, "XlaDot");
+      || absl::EndsWith(tf_op_name, "XlaDot")
+      || absl::EndsWith(tf_op_name, "XlaDotV2");
   // clang-format on
 }