[TF:XLA] Handle bitcasts between different bitwidths
PiperOrigin-RevId: 251810931
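
Previously the XLA kernel for tf.bitcast rejected source/destination types of different bit widths; this change lowers narrowing casts to a shift-and-mask decomposition and widening casts to a shift-and-OR reduction. As a quick illustration of the op's shape semantics (a sketch for context, not part of the diff; assumes TF2 eager execution):

import numpy as np
import tensorflow as tf

x = np.arange(4, dtype=np.int32)     # shape [4], 32-bit elements
small = tf.bitcast(x, tf.int8)       # narrowing adds an inner dim of 32/8 = 4
print(small.shape)                   # (4, 4)

back = tf.bitcast(small, tf.int32)   # widening consumes the size-4 inner dim
print(back.shape)                    # (4,)
print(np.array_equal(back.numpy(), x))  # True: the round trip preserves bits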
parent e3b76130b4
commit 3ed65f64e6
@@ -12,16 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
+#include "tensorflow/compiler/tf2xla/lib/util.h"
+
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/lib/core/errors.h"

 namespace tensorflow {
 namespace {
@@ -112,23 +117,85 @@ class BitcastOp : public XlaOpKernel {

     if (src_dtype_ == dst_dtype_) {
       output = input;
-    } else {
+      ctx->SetOutput(0, output);
+      return;
     }
     // Error out if the bitcast has a complex source or destination type and
     // the bitcast is not trivial.
     OP_REQUIRES(ctx,
                 !xla::primitive_util::IsComplexType(src_type_) &&
                     !xla::primitive_util::IsComplexType(dst_type_),
                 errors::Unimplemented("Complex types not supported."));
-    // XLA bitcast requires that the bit-width of the source and destination
-    // matches, and currently only the simple lowering is performed.
+    auto input_bit_width = xla::primitive_util::BitWidth(src_type_);
+    auto output_bit_width = xla::primitive_util::BitWidth(dst_type_);
+
+    auto input_logical_type =
+        xla::primitive_util::UnsignedIntegralTypeForBitWidth(input_bit_width);
+    auto output_logical_type =
+        xla::primitive_util::UnsignedIntegralTypeForBitWidth(output_bit_width);
+
     OP_REQUIRES(ctx,
-                xla::primitive_util::BitWidth(src_type_) ==
-                    xla::primitive_util::BitWidth(dst_type_),
-                errors::Unimplemented(
-                    "Only bitcasts between equally sized types supported."));
-    output = xla::BitcastConvertType(input, dst_type_);
+                output_bit_width % input_bit_width == 0 ||
+                    input_bit_width % output_bit_width == 0,
+                errors::InvalidArgument(
+                    "Neither bit width is a multiple of the other."));
+
+    // Modify the input as needed so we only need to bitcast to create the
+    // output.
+    if (input_bit_width > output_bit_width) {
+      // Casting to a smaller bit width results in a new inner dimension.
+      auto broadcasted_input_shape = ctx->InputShape(0);
+      auto reshaped_input_shape = ctx->InputShape(0);
+      broadcasted_input_shape.AddDim(input_bit_width / output_bit_width);
+      reshaped_input_shape.AddDim(1);
+      auto output_bit_width_mask = (1 << output_bit_width) - 1;
+
+      auto status_or_input =
+          BroadcastTo(xla::Reshape(input, reshaped_input_shape.dim_sizes()),
+                      broadcasted_input_shape.dim_sizes());
+      OP_REQUIRES_OK(ctx, status_or_input.status());
+      input = xla::BitcastConvertType(status_or_input.ConsumeValueOrDie(),
+                                      input_logical_type);
+      auto xla_input_shape_status = ctx->builder()->GetShape(input);
+      OP_REQUIRES_OK(ctx, xla_input_shape_status.status());
+      auto xla_input_shape = xla_input_shape_status.ConsumeValueOrDie();
+
+      auto iota = xla::Iota(ctx->builder(), xla_input_shape,
+                            xla_input_shape.dimensions_size() - 1);
+      xla::XlaOp iota_m =
+          xla::Mul(xla::ScalarLike(input, output_bit_width), iota);
+      input = xla::And(xla::ShiftRightLogical(input, iota_m),
+                       xla::ScalarLike(input, output_bit_width_mask));
+      input = xla::ConvertElementType(input, output_logical_type);
+    } else if (input_bit_width < output_bit_width) {
+      // Casting to a larger bit width results in removing the innermost
+      // dimension.
+      auto input_shape = ctx->InputShape(0);
+      xla::Shape input_xla_shape =
+          TensorShapeToXLAShape(dst_type_, input_shape);
+      OP_REQUIRES(
+          ctx,
+          input_shape.dim_size(input_shape.dims() - 1) ==
+              output_bit_width / input_bit_width,
+          errors::InvalidArgument(
+              "Inner dimension of operand should be removed after cast."));
+
+      auto zero = XlaHelpers::Zero(ctx->builder(), dst_dtype_);
+      input = xla::ConvertElementType(input, dst_type_);
+
+      // Shift bits and OR them together to reduce the inner dimension.
+      xla::XlaOp iota_m =
+          xla::Mul(xla::ScalarLike(input, input_bit_width),
+                   xla::Iota(ctx->builder(), input_xla_shape,
+                             input_xla_shape.dimensions_size() - 1));
+      input = xla::ShiftLeft(input, iota_m);
+      input = xla::Reduce(input, zero,
+                          CreateScalarOrComputation(dst_type_, ctx->builder()),
+                          {input_xla_shape.dimensions_size() - 1});
+    }
+
+    output = xla::BitcastConvertType(input, dst_type_);
     ctx->SetOutput(0, output);
   }
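
The two branches above avoid a direct bitcast by reformulating the conversion as integer arithmetic: narrowing broadcasts each element across a new inner dimension and extracts lanes with ShiftRightLogical/And; widening shifts each lane into position and combines them with an Or reduction. A minimal numpy sketch of the same arithmetic for uint32 <-> uint8 (illustrative only; the helper names are mine, and lane order follows the kernel's LSB-first shifts):

import numpy as np

def narrow_u32_to_u8(x):
  # Narrowing: add an inner dimension of size 32/8 = 4, shift right by
  # 0, 8, 16, 24 bits, and mask to 8 bits -- mirroring the
  # ShiftRightLogical/And lowering in the kernel.
  x = x.astype(np.uint32)[..., np.newaxis]               # inner dim of 1
  shifts = np.uint32(8) * np.arange(4, dtype=np.uint32)  # iota * bit width
  return ((x >> shifts) & np.uint32(0xFF)).astype(np.uint8)

def widen_u8_to_u32(x):
  # Widening: shift each 8-bit lane into place and OR the lanes together
  # across the innermost dimension, mirroring the ShiftLeft/Reduce(Or)
  # lowering.
  x = x.astype(np.uint32)
  shifts = np.uint32(8) * np.arange(4, dtype=np.uint32)
  return np.bitwise_or.reduce(x << shifts, axis=-1)

a = np.array([0x01020304, 0xDEADBEEF], dtype=np.uint32)
lanes = narrow_u32_to_u8(a)                        # shape (2, 4)
assert np.array_equal(widen_u8_to_u32(lanes), a)   # round trip recovers a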
@@ -38,14 +38,12 @@ class BitcastTest(test.TestCase):
     self.assertEqual(tf_ans.get_shape(), shape)
     self.assertEqual(tf_ans.dtype, datatype)

-  @test_util.disable_xla("Different bitwidths not supported")
   def testSmaller(self):
     x = np.random.rand(3, 2)
     datatype = dtypes.int8
     shape = [3, 2, 8]
     self._testBitcast(x, datatype, shape)

-  @test_util.disable_xla("Different bitwidths not supported")
   def testLarger(self):
     x = np.arange(16, dtype=np.int8).reshape([4, 4])
     datatype = dtypes.int32
@@ -69,7 +67,6 @@ class BitcastTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError, "Cannot bitcast due to shape"):
       array_ops.bitcast(x, datatype, None)

-  @test_util.disable_xla("Different bitwidths not supported")
   def testEmpty(self):
     x = np.ones([], np.int32)
     datatype = dtypes.int8
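
With the disable_xla decorators removed, these tests now exercise the new lowering as well. A hedged sketch of forcing the XLA path explicitly (jit_compile is the present-day API, not the one available at the time of this commit, which used the experimental xla.compile wrapper):

import numpy as np
import tensorflow as tf

@tf.function(jit_compile=True)  # compile through XLA, hitting BitcastOp
def bitcast_to_int8(x):
  return tf.bitcast(x, tf.int8)

x = np.arange(16, dtype=np.int32).reshape(4, 4)
print(bitcast_to_int8(x).shape)  # (4, 4, 4): 32-bit -> 8-bit adds an inner dim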