Add legalizations for RngReadAndSkip and StatelessRandomUniformFullIntV2.
PiperOrigin-RevId: 357766922
Change-Id: I6f3806d714cbeb62e5deafe580c9641b1f0a9e60
This commit is contained in:
parent e2cd9cda80
commit 3af7f6d31f

@@ -698,6 +698,7 @@ cc_library(
        ":decompose_resource_ops_inc_gen",
        ":tensorflow",
        ":tensorflow_types",
        "//tensorflow/core:framework",
        "@llvm-project//mlir:IR",
    ],
)

@@ -12646,6 +12646,27 @@ def TF_RiscDotOp : TF_Op<"RiscDot", [NoSideEffect]> {

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_RngReadAndSkipOp : TF_Op<"RngReadAndSkip", []> {
  let summary = "Advance the counter of a counter-based RNG.";

  let description = [{
The state of the RNG after
`rng_read_and_skip(n)` will be the same as that after `uniform([n])`
(or any other distribution). The actual increment added to the
counter is an unspecified implementation choice.
  }];

  let arguments = (ins
    TF_ResourceTensor:$resource,
    TF_Int32Tensor:$alg,
    TF_Uint64Tensor:$delta
  );

  let results = (outs
    TF_Int64Tensor:$value
  );
}

def TF_RollOp : TF_Op<"Roll", [NoSideEffect]> {
  let summary = "Rolls the elements of a tensor along an axis.";

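The contract above is observable through the public tf.random.Generator API, whose skip method lowers to RngReadAndSkip on recent TF versions. A minimal sketch, assuming a TF 2.x eager runtime (the seed value is an arbitrary example):

import tensorflow as tf

g1 = tf.random.Generator.from_seed(1234)
g2 = tf.random.Generator.from_seed(1234)

g1.skip(10)           # advance the counter without producing values
_ = g2.uniform([10])  # draw 10 values, advancing the counter the same way

# Per the documented contract, both generators now hold identical RNG state.
assert (g1.state.numpy() == g2.state.numpy()).all()
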
@@ -15011,6 +15032,32 @@ The outputs are a deterministic function of `shape` and `seed`.

  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}

def TF_StatelessRandomUniformFullIntV2Op : TF_Op<"StatelessRandomUniformFullIntV2", [NoSideEffect]> {
  let summary = [{
Outputs deterministic pseudorandom integers from a uniform distribution.
  }];

  let description = [{
The generated values are uniform integers covering the whole range of `dtype`.

The outputs are a deterministic function of `shape`, `key`, `counter` and `alg`.
  }];

  let arguments = (ins
    Arg<TF_I32OrI64Tensor, [{The shape of the output tensor.}]>:$shape,
    Arg<TF_Uint64Tensor, [{Key for the counter-based RNG algorithm (shape uint64[1]).}]>:$key,
    Arg<TF_Uint64Tensor, [{Initial counter for the counter-based RNG algorithm (shape uint64[2] or uint64[1] depending on the algorithm). If a larger vector is given, only the needed portion on the left (i.e. [:N]) will be used.}]>:$counter,
    Arg<TF_Int32Tensor, [{The RNG algorithm (shape int32[]).}]>:$alg
  );

  let results = (outs
    Res<TensorOf<[TF_Int32, TF_Int64, TF_Uint32, TF_Uint64]>, [{Random values with specified shape.}]>:$output
  );

  TF_DerivedOperandTypeAttr Tshape = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>;
}

def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect, TF_NoConstantFold]> {
  let summary = [{
Outputs deterministic pseudorandom integers from a uniform distribution.

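As a usage sketch (hedged: this goes through the public tf.random.stateless_uniform API, where passing minval=None and maxval=None with an integer dtype selects the full-range op; the V2 variant differs mainly in threading key/counter/alg explicitly rather than deriving them from a seed):

import tensorflow as tf

seed = [1, 2]  # arbitrary example seed
x = tf.random.stateless_uniform(
    [3, 4], seed=seed, minval=None, maxval=None, dtype=tf.uint64)
y = tf.random.stateless_uniform(
    [3, 4], seed=seed, minval=None, maxval=None, dtype=tf.uint64)
# Deterministic: the same shape and seed yield identical full-range integers.
assert (x.numpy() == y.numpy()).all()
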
@@ -568,3 +568,17 @@ func @decompose_resource_apply_proximal_adagrad_op(%lr: tensor<f32>, %l1: tensor

  return
}

// -----

// Test that tf.RngReadAndSkip op is decomposed.
// CHECK-LABEL: func @decompose_rng_read_and_skip_op
func @decompose_rng_read_and_skip_op(%resource: tensor<!tf.resource<tensor<3xi64>>>) -> tensor<3xi64> {
  // We rely on the TensorFlow StatefulRandomOpsTest to check it is lowered
  // correctly.
  // CHECK-NOT: tf.RngReadAndSkip
  %alg = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %delta = "tf.Const"() {value = dense<10> : tensor<ui64>} : () -> tensor<ui64>
  %0 = "tf.RngReadAndSkip"(%resource, %alg, %delta) : (tensor<!tf.resource<tensor<3xi64>>>, tensor<i32>, tensor<ui64>) -> tensor<3xi64>
  return %0 : tensor<3xi64>
}

@@ -18,6 +18,7 @@ limitations under the License.

#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/framework/rng_alg.h"

namespace mlir {
namespace TF {

@@ -68,11 +69,152 @@ static Type GetResourceSubtype(Value resource) {
      .front();
}

// Decompose tf.RngReadAndSkip.
//
// For Philox, the resource variable holds a tensor<3xi64> with the state:
//   [counter_lo, counter_hi, key]
//
// RngReadAndSkip increments the 128 bit counter value by 256 * delta and
// returns the original state value.
//
// For Threefry, the resource variable holds a tensor<2xi64> with the state:
//   [counter, key]
//
// RngReadAndSkip increments the 64 bit counter value by 256 * delta and
// returns a tensor<3xi64> value [counter, key, 0].
class DecomposeRngReadAndSkipOp : public RewritePattern {
 public:
  explicit DecomposeRngReadAndSkipOp(MLIRContext *context)
      : RewritePattern(RngReadAndSkipOp::getOperationName(),
                       {
                           AddV2Op::getOperationName(),
                           AssignVariableOp::getOperationName(),
                           CastOp::getOperationName(),
                           ConstOp::getOperationName(),
                           LessOp::getOperationName(),
                           MulOp::getOperationName(),
                           PadOp::getOperationName(),
                           PackOp::getOperationName(),
                           ReadVariableOp::getOperationName(),
                           SelectV2Op::getOperationName(),
                           UnpackOp::getOperationName(),
                       },
                       1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto rng_op = cast<RngReadAndSkipOp>(op);

    DenseIntElementsAttr alg_constant;
    if (!matchPattern(rng_op.alg(), m_Constant(&alg_constant))) {
      op->emitOpError() << "unable to determine algorithm statically";
      return failure();
    }

    if (alg_constant.getNumElements() != 1) {
      op->emitOpError() << "expected alg to be a scalar";
      return failure();
    }

    uint64_t alg_value = (*alg_constant.int_value_begin()).getZExtValue();
    tensorflow::Algorithm alg;
    if (tensorflow::RNG_ALG_PHILOX == alg_value) {
      alg = tensorflow::RNG_ALG_PHILOX;
    } else if (tensorflow::RNG_ALG_THREEFRY == alg_value) {
      alg = tensorflow::RNG_ALG_THREEFRY;
    } else {
      op->emitOpError() << "unsupported alg";
      return failure();
    }

    Type state_element_type = rewriter.getI64Type();
    RankedTensorType op_type = RankedTensorType::get(
        {tensorflow::RNG_MAX_COUNTER_SIZE + tensorflow::RNG_KEY_SIZE},
        state_element_type);
    if (op_type != rng_op.getType()) {
      op->emitOpError() << "unexpected op type";
      return failure();
    }

    if (!HasResourceSubtype(rng_op.resource())) {
      op->emitOpError() << "missing resource subtype";
      return failure();
    }

    int counter_size = tensorflow::GetCounterSize(alg);
    int state_size = counter_size + tensorflow::RNG_KEY_SIZE;
    RankedTensorType res_type =
        RankedTensorType::get({state_size}, state_element_type);
    if (res_type != GetResourceSubtype(rng_op.resource())) {
      op->emitOpError() << "unexpected resource subtype";
      return failure();
    }

    Location loc = op->getLoc();

    // Read the state value from the resource.
    Value state =
        rewriter.create<ReadVariableOp>(loc, res_type, rng_op.resource());

    // Extract the key and counter from the state.
    RankedTensorType word_type = RankedTensorType::get({}, state_element_type);
    auto unpacked = rewriter.create<UnpackOp>(
        loc, SmallVector<Type, 4>(state_size, word_type), state, 0);
    Value key = unpacked.getResult(counter_size);

    SmallVector<Value, 4> counter;
    for (int i = 0; i < counter_size; ++i) {
      counter.push_back(unpacked.getResult(i));
    }

    // Set the increment to 256 * delta.
    Type u64 = rewriter.getIntegerType(64, /*isSigned=*/false);
    RankedTensorType u64_scalar = RankedTensorType::get({}, u64);
    Value step_size = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 256));
    Value increment =
        rewriter.create<MulOp>(loc, u64_scalar, step_size, rng_op.delta());

    // Increment the counter.
    SmallVector<Value, 4> pack_args;
    RankedTensorType word_u64_type = RankedTensorType::get({}, u64);
    Value zero_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 0));
    Value one_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 1));
    for (int i = 0; i < counter_size; ++i) {
      Value word = counter[i];
      Value word_u64 = rewriter.create<CastOp>(loc, word_u64_type, word);
      Value new_word_u64 = rewriter.create<AddV2Op>(loc, word_u64, increment);
      Value new_word = rewriter.create<CastOp>(loc, word_type, new_word_u64);
      pack_args.push_back(new_word);

      // If the unsigned addition wrapped around, carry a 1 into the next
      // (more significant) counter word; otherwise carry 0.
      Value overflow = rewriter.create<LessOp>(loc, new_word_u64, word_u64);
      increment = rewriter.create<SelectV2Op>(loc, overflow, one_u64, zero_u64);
    }

    // Save the new state value to the resource.
    pack_args.push_back(key);
    Value new_state = rewriter.create<PackOp>(loc, res_type, pack_args);
    rewriter.create<AssignVariableOp>(loc, rng_op.resource(), new_state);

    // Pad the original state as necessary to fill the output shape.
    int pad = tensorflow::RNG_MAX_COUNTER_SIZE - counter_size;
    Type i64 = rewriter.getI64Type();
    RankedTensorType paddings_ty = RankedTensorType::get({1, 2}, i64);
    std::vector<int64_t> paddings_values = {0, pad};
    Value paddings = rewriter.create<ConstOp>(
        loc, DenseIntElementsAttr::get(paddings_ty, paddings_values));
    Value output = rewriter.create<PadOp>(loc, op_type, state, paddings);

    rewriter.replaceOp(op, output);
    return success();
  }
};

#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc"
}  // namespace

void PopulateDecomposeResourceOpsPatterns(MLIRContext *context,
                                          OwningRewritePatternList *patterns) {
  patterns->insert<DecomposeRngReadAndSkipOp>(context);
  populateWithGenerated(context, *patterns);
}

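The carry chain the pattern emits can be sanity-checked with plain integer arithmetic. A minimal sketch of the Philox case (a hypothetical skip_philox helper in Python, not TensorFlow code):

MASK64 = (1 << 64) - 1

def skip_philox(counter_lo, counter_hi, delta):
    """Mirrors the decomposition: add 256*delta to the low counter word,
    then propagate a carry of 1 (the LessOp/SelectV2Op pair) on wraparound."""
    increment = (256 * delta) & MASK64
    new_lo = (counter_lo + increment) & MASK64
    carry = 1 if new_lo < counter_lo else 0
    new_hi = (counter_hi + carry) & MASK64
    return new_lo, new_hi

# Example: a low word near 2**64 wraps and carries into the high word.
assert skip_philox(2**64 - 10, 0, 1) == (246, 1)
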
@@ -241,6 +241,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
    TypeID::get<TF::StatelessRandomNormalV2Op>(),
    TypeID::get<TF::StatelessRandomUniformOp>(),
    TypeID::get<TF::StatelessRandomUniformFullIntOp>(),
    TypeID::get<TF::StatelessRandomUniformFullIntV2Op>(),
    TypeID::get<TF::StatelessRandomUniformV2Op>(),
    TypeID::get<TF::StatelessRandomUniformIntOp>(),
    TypeID::get<TF::StatelessRandomUniformIntV2Op>(),

@@ -1277,6 +1277,7 @@ tf_xla_py_test(
    name = "stateful_random_ops_test",
    size = "medium",
    srcs = ["stateful_random_ops_test.py"],
    enable_mlir_bridge = True,
    python_version = "PY3",
    shard_count = 10,
    tags = [

@@ -31,6 +31,7 @@ from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.random import util as \
    random_test_util
from tensorflow.python.ops import gen_stateful_random_ops

@@ -76,6 +77,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
      gen.uniform_full_int(shape=(3,))

  @parameterized.parameters(ALGS)
  @test_util.disable_mlir_bridge("TODO(b/180412889): Crashes with MLIR bridge.")
  def testDefun(self, alg):
    """Test for defun."""
    with ops.device(xla_device_name()):

@@ -248,6 +250,36 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
          shape=shape, dtype=dtype))
      self.assertAllEqual(cpu, xla)

  def testXLAEqualsCPUAroundCounterOverflow(self):
    """Tests that XLA and CPU kernels generate the same integers on overflow.

    Specifically this tests the case where the counter is incremented past
    what can fit within 64 bits of the 128 bit Philox counter.
    """
    dtype = dtypes.uint64
    seed = 2**64 - 10
    shape = [315, 49]
    if compat.forward_compatible(2020, 10, 25):
      with ops.device("/device:CPU:0"):
        cpu_gen = random.Generator.from_seed(
            seed=seed, alg=random.RNG_ALG_PHILOX)
      with ops.device(xla_device_name()):
        xla_gen = random.Generator.from_seed(
            seed=seed, alg=random.RNG_ALG_PHILOX)
      # Repeat multiple times to make sure that the state after
      # number generation is the same between CPU and XLA.
      for _ in range(5):
        with ops.device("/device:CPU:0"):
          # Test both number generation and skip.
          cpu = cpu_gen.uniform_full_int(shape=shape, dtype=dtype)
          cpu_gen.skip(100)
        with ops.device(xla_device_name()):
          xla = xla_gen.uniform_full_int(shape=shape, dtype=dtype)
          xla_gen.skip(100)
        self.assertAllEqual(cpu, xla)
        self.assertAllEqual(cpu_gen.state, xla_gen.state)
      self.assertAllEqual(cpu, xla)

  def _testRngIsNotConstant(self, rng, dtype):
    # Tests that 'rng' does not always return the same value.
    # The random-number generator, if working correctly, should produce the

@@ -352,6 +384,8 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
        mean_atol=2e-3, median_atol=4e-3,
        variance_rtol=1e-2 if dtype == dtypes.bfloat16 else 5e-3)

  @test_util.disable_mlir_bridge(
      "b/180412086: MLIR bridge gives wrong error messages.")
  def testErrors(self):
    """Tests that proper errors are raised.
    """