[TF:XLA] Handle bitcasts between different bitwidths
PiperOrigin-RevId: 251810931
commit 3ed65f64e6
parent e3b76130b4
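A note on semantics, summarized from the checks and tests in this change: tf.bitcast reinterprets a tensor's underlying bits as another dtype. When the source and destination bit widths differ, one width must divide the other, and the shape changes by the width ratio: casting to a narrower type appends an inner dimension of that ratio, while casting to a wider type consumes an inner dimension that must equal it. Using the shapes exercised by the tests below:

    float64 [3, 2]  -> bitcast to int8  ->  int8 [3, 2, 8]   (64 / 8 = 8, inner dim appended)
    int8 [4, 4]     -> bitcast to int32 ->  int32 [4]        (4 == 32 / 8, inner dim consumed)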
@@ -12,16 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
 #include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
-
 
 namespace tensorflow {
 namespace {
@@ -112,23 +117,85 @@ class BitcastOp : public XlaOpKernel {
     xla::XlaOp input = ctx->Input(0);
     xla::XlaOp output;
 
     if (src_dtype_ == dst_dtype_) {
       output = input;
-    } else {
-      // Error out if the bitcast has a complex source or destination type and
-      // the bitcast is not trivial.
-      OP_REQUIRES(ctx,
-                  !xla::primitive_util::IsComplexType(src_type_) &&
-                      !xla::primitive_util::IsComplexType(dst_type_),
-                  errors::Unimplemented("Complex types not supported."));
-      // XLA bitcast requires that the bit-width of the source and destination
-      // matches, and currently only the simple lowering is performed.
-      OP_REQUIRES(ctx,
-                  xla::primitive_util::BitWidth(src_type_) ==
-                      xla::primitive_util::BitWidth(dst_type_),
-                  errors::Unimplemented(
-                      "Only bitcasts between equally sized types supported."));
-      output = xla::BitcastConvertType(input, dst_type_);
+      ctx->SetOutput(0, output);
+      return;
     }
+    // Error out if the bitcast has a complex source or destination type and
+    // the bitcast is not trivial.
+    OP_REQUIRES(ctx,
+                !xla::primitive_util::IsComplexType(src_type_) &&
+                    !xla::primitive_util::IsComplexType(dst_type_),
+                errors::Unimplemented("Complex types not supported."));
+    auto input_bit_width = xla::primitive_util::BitWidth(src_type_);
+    auto output_bit_width = xla::primitive_util::BitWidth(dst_type_);
+
+    auto input_logical_type =
+        xla::primitive_util::UnsignedIntegralTypeForBitWidth(input_bit_width);
+    auto output_logical_type =
+        xla::primitive_util::UnsignedIntegralTypeForBitWidth(output_bit_width);
+
+    OP_REQUIRES(ctx,
+                output_bit_width % input_bit_width == 0 ||
+                    input_bit_width % output_bit_width == 0,
+                errors::InvalidArgument(
+                    "Neither bit width is a multiple of the other."));
+
+    // Modify the input as needed so we only need to bitcast to create the
+    // output.
+    if (input_bit_width > output_bit_width) {
+      // Casting to a smaller bit width results in a new inner dimension.
+      auto broadcasted_input_shape = ctx->InputShape(0);
+      auto reshaped_input_shape = ctx->InputShape(0);
+      broadcasted_input_shape.AddDim(input_bit_width / output_bit_width);
+      reshaped_input_shape.AddDim(1);
+      auto output_bit_width_mask = (1 << output_bit_width) - 1;
+
+      auto status_or_input =
+          BroadcastTo(xla::Reshape(input, reshaped_input_shape.dim_sizes()),
+                      broadcasted_input_shape.dim_sizes());
+      OP_REQUIRES_OK(ctx, status_or_input.status());
+      input = xla::BitcastConvertType(status_or_input.ConsumeValueOrDie(),
+                                      input_logical_type);
+      auto xla_input_shape_status = ctx->builder()->GetShape(input);
+      OP_REQUIRES_OK(ctx, xla_input_shape_status.status());
+      auto xla_input_shape = xla_input_shape_status.ConsumeValueOrDie();
+
+      auto iota = xla::Iota(ctx->builder(), xla_input_shape,
+                            xla_input_shape.dimensions_size() - 1);
+      xla::XlaOp iota_m =
+          xla::Mul(xla::ScalarLike(input, output_bit_width), iota);
+      input = xla::And(xla::ShiftRightLogical(input, iota_m),
+                       xla::ScalarLike(input, output_bit_width_mask));
+      input = xla::ConvertElementType(input, output_logical_type);
+    } else if (input_bit_width < output_bit_width) {
+      // Casting to a larger bit width results in removing the innermost
+      // dimension.
+      auto input_shape = ctx->InputShape(0);
+      xla::Shape input_xla_shape =
+          TensorShapeToXLAShape(dst_type_, input_shape);
+      OP_REQUIRES(
+          ctx,
+          input_shape.dim_size(input_shape.dims() - 1) ==
+              output_bit_width / input_bit_width,
+          errors::InvalidArgument(
+              "Inner dimension of operand should be removed after cast."));
+
+      auto zero = XlaHelpers::Zero(ctx->builder(), dst_dtype_);
+      input = xla::ConvertElementType(input, dst_type_);
+
+      // Shift bits and OR them together to reduce the inner dimension.
+      xla::XlaOp iota_m =
+          xla::Mul(xla::ScalarLike(input, input_bit_width),
+                   xla::Iota(ctx->builder(), input_xla_shape,
+                             input_xla_shape.dimensions_size() - 1));
+      input = xla::ShiftLeft(input, iota_m);
+      input = xla::Reduce(input, zero,
+                          CreateScalarOrComputation(dst_type_, ctx->builder()),
+                          {input_xla_shape.dimensions_size() - 1});
+    }
+
+    output = xla::BitcastConvertType(input, dst_type_);
     ctx->SetOutput(0, output);
   }
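To make the new lowering above easier to follow, here is a minimal standalone C++ sketch. It is not part of the commit and all names in it are illustrative. The narrowing branch broadcasts each word across a new inner dimension, then isolates lane i with an iota-scaled logical right shift and a mask; the widening branch widens lane i, left-shifts it by i * input_bit_width, and OR-reduces away the inner dimension. The scalar loops below mirror that arithmetic for a 32-bit/8-bit pair:

#include <cstdint>
#include <cstdio>

int main() {
  // Narrowing (input_bit_width = 32 > output_bit_width = 8): lane i is
  // (x >> (8 * i)) & 0xFF, the scalar analogue of ShiftRightLogical by the
  // iota multiple followed by And with output_bit_width_mask = (1 << 8) - 1.
  uint32_t x = 0x0A0B0C0D;
  uint8_t lanes[4];
  for (int i = 0; i < 4; ++i) {
    lanes[i] = static_cast<uint8_t>((x >> (8 * i)) & 0xFFu);
  }
  printf("lanes: %02X %02X %02X %02X\n", lanes[0], lanes[1], lanes[2],
         lanes[3]);  // prints: lanes: 0D 0C 0B 0A

  // Widening (input_bit_width = 8 < output_bit_width = 32): widen lane i,
  // shift left by 8 * i, and OR the results together, the scalar analogue of
  // ShiftLeft by the iota multiple followed by the OR Reduce built with
  // CreateScalarOrComputation over the inner dimension.
  uint32_t y = 0;
  for (int i = 0; i < 4; ++i) {
    y |= static_cast<uint32_t>(lanes[i]) << (8 * i);
  }
  printf("recombined: 0x%08X\n", y);  // prints: recombined: 0x0A0B0C0D
  return 0;
}

Round-tripping recovers the original word, which is why a bitcast down and back up is lossless whenever the inner dimension matches the width ratio.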
@@ -38,14 +38,12 @@ class BitcastTest(test.TestCase):
     self.assertEqual(tf_ans.get_shape(), shape)
     self.assertEqual(tf_ans.dtype, datatype)
 
-  @test_util.disable_xla("Different bitwidths not supported")
   def testSmaller(self):
     x = np.random.rand(3, 2)
     datatype = dtypes.int8
     shape = [3, 2, 8]
     self._testBitcast(x, datatype, shape)
 
-  @test_util.disable_xla("Different bitwidths not supported")
   def testLarger(self):
     x = np.arange(16, dtype=np.int8).reshape([4, 4])
     datatype = dtypes.int32
@@ -69,7 +67,6 @@ class BitcastTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError, "Cannot bitcast due to shape"):
       array_ops.bitcast(x, datatype, None)
 
-  @test_util.disable_xla("Different bitwidths not supported")
   def testEmpty(self):
     x = np.ones([], np.int32)
     datatype = dtypes.int8