diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
index 42df840ec63..5db52781be4 100644
--- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
@@ -29,7 +29,6 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/lib/math/math_util.h"
 
 namespace tensorflow {
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index cc7390c6e60..fc74ef4aa34 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -308,6 +308,7 @@ cc_library(
         ":util",
         ":xla_data_proto",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/base",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
@@ -330,6 +331,7 @@ tf_cc_test(
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "@com_google_absl//absl/base",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
     ],
@@ -373,6 +375,7 @@ cc_library(
         ":literal_util",
         ":util",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/base",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
     ],
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index a18c94c4e69..f833ddcd323 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -169,6 +169,7 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/base",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc
index e71d899e323..c6f68c8ee2f 100644
--- a/tensorflow/compiler/xla/client/lib/prng.cc
+++ b/tensorflow/compiler/xla/client/lib/prng.cc
@@ -15,12 +15,12 @@ limitations under the License.
 
 #include <cmath>
 
+#include "absl/base/casts.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/lib/numeric.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/core/casts.h"
 
 namespace xla {
 namespace {
@@ -149,7 +149,7 @@ XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
   constexpr int kMantissaBits = 23;
   bits = ShiftRightLogical(
              bits, ConstantR0<uint32>(builder, kFloatBits - kMantissaBits)) |
-         ConstantR0<uint32>(builder, tensorflow::bit_cast<uint32>(1.0f));
+         ConstantR0<uint32>(builder, absl::bit_cast<uint32>(1.0f));
   auto floats = BitcastConvertType(bits, F32);
 
   // We have a floating point number in the range [1.0, 2.0).
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 8a8f49ccd04..510aa39b450 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include <numeric>
 #include <vector>
 
+#include "absl/base/casts.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
@@ -31,7 +32,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/logging.h"
@@ -1233,7 +1233,7 @@ typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
                         Literal>::type
 BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
   auto converter = [](NativeSrcT src) {
-    return tensorflow::bit_cast<NativeDestT>(src);
+    return absl::bit_cast<NativeDestT>(src);
   };
   return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
       src_literal, converter);
@@ -1995,7 +1995,7 @@ string LiteralBase::GetR1U8AsString() const {
   CHECK(ShapeUtil::IsArray(shape()));
   CHECK_EQ(ShapeUtil::Rank(shape()), 1);
   CHECK_EQ(shape().element_type(), U8);
-  return string(tensorflow::bit_cast<const char*>(data<uint8>().data()),
+  return string(absl::bit_cast<const char*>(data<uint8>().data()),
                 ShapeUtil::ElementsIn(shape()));
 }
 
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc
index 3d8725ed705..8d4b974c166 100644
--- a/tensorflow/compiler/xla/literal_comparison.cc
+++ b/tensorflow/compiler/xla/literal_comparison.cc
@@ -19,11 +19,11 @@ limitations under the License.
 #include <cmath>
 #include <vector>
 
+#include "absl/base/casts.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/platform/env.h"
 
 using absl::StrAppend;
@@ -40,8 +40,10 @@ namespace {
 template <typename FloatT, typename UnsignedT>
 Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs,
                                  absl::Span<const int64> multi_index) {
-  auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
-  auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
+  // TODO(b/118627822): These are unsafe bit_casts because Eigen::Half is not
+  // trivially copyable.
+  auto ulhs = absl::bit_cast<UnsignedT>(lhs);
+  auto urhs = absl::bit_cast<UnsignedT>(rhs);
   auto lhs_double = static_cast<double>(lhs);
   auto rhs_double = static_cast<double>(rhs);
   if (ulhs != urhs) {
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index dd5b54e4c99..4ae5ddbfdb8 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include <vector>
 
+#include "absl/base/casts.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
@@ -28,7 +29,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
@@ -1312,11 +1312,10 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
 
 TEST_F(LiteralUtilTest, BitcastConvert) {
   auto original = LiteralUtil::CreateR1<uint32>(
-      {tensorflow::bit_cast<uint32>(2.5f),
-       tensorflow::bit_cast<uint32>(-42.25f),
-       tensorflow::bit_cast<uint32>(100.f), 0xbeef});
+      {absl::bit_cast<uint32>(2.5f), absl::bit_cast<uint32>(-42.25f),
+       absl::bit_cast<uint32>(100.f), 0xbeef});
   auto expected = LiteralUtil::CreateR1<float>(
-      {2.5f, -42.25f, 100.0f, tensorflow::bit_cast<float>(0xbeef)});
+      {2.5f, -42.25f, 100.0f, absl::bit_cast<float>(0xbeef)});
   TF_ASSERT_OK_AND_ASSIGN(Literal converted, original.BitcastConvert(F32));
 }
 
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 0cb1ae35f4a..bb5e5e61000 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -30,7 +30,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 3a716c385b2..017b11465d1 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -253,6 +253,7 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/base",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/container:node_hash_map",
         "@com_google_absl//absl/memory",
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 58abb330a6e..36e25cbe678 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -51,6 +51,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/stream_executor",
+        "@com_google_absl//absl/base",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/types:span",
     ],
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index 1cc28444703..1457582ac19 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
+#include "absl/base/casts.h"
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/literal_util.h"
@@ -29,7 +30,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/platform/logging.h"
@@ -183,7 +183,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
     // Note: OSS build didn't like implicit conversion from
     // literal_shape.dimensions() to the array slice on 2017-07-10.
     absl::Span<const int64> dimensions(
-        tensorflow::bit_cast<const int64*>(literal_shape.dimensions().data()),
+        absl::bit_cast<const int64*>(literal_shape.dimensions().data()),
         literal_shape.dimensions().size());
     TF_ASSIGN_OR_RETURN(
         Shape received_shape,
diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc
index cef5e57b0b1..f9722ffadac 100644
--- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc
@@ -22,7 +22,6 @@ limitations under the License.
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/platform/logging.h"
 
 namespace xla {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index c2998883851..7fcafafc097 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -43,7 +43,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/window_util.h"
 #include "tensorflow/core/lib/core/bitmap.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 84fbbd3e0c3..ebed875eb49 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include <cmath>
 
 #include "absl/algorithm/container.h"
+#include "absl/base/casts.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/memory/memory.h"
 #include "absl/types/optional.h"
@@ -27,7 +28,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/service/shape_inference.h"
-#include "tensorflow/core/lib/core/casts.h"
 
 namespace xla {
 
@@ -2442,7 +2442,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
         parent_->evaluated_[reduce_precision],
         ElementWiseUnaryOp(reduce_precision, [reduce_precision](
                                                  ElementwiseT elem) {
-          uint32_t value_as_int = tensorflow::bit_cast<uint32_t>(elem);
+          uint32_t value_as_int = absl::bit_cast<uint32_t>(elem);
           const uint32_t mantissa_bits = reduce_precision->mantissa_bits();
           const uint32_t exponent_bits = reduce_precision->exponent_bits();
 
@@ -2515,7 +2515,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
             value_as_int = x_underflows ? x_signed_zero : value_as_int;
           }
 
-          float reduced_result = tensorflow::bit_cast<float>(value_as_int);
+          float reduced_result = absl::bit_cast<float>(value_as_int);
           if (std::isnan(elem)) {
             reduced_result = mantissa_bits > 0
                                  ? elem
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index 5f7ad81d829..850501a4b5c 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -72,6 +72,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_module_config",
         "//tensorflow/compiler/xla/service:name_uniquer",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/base",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
         "@llvm//:core",
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index 1a53c026be3..2e5aebb74c2 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include <memory>
 #include <vector>
 
+#include "absl/base/casts.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "llvm/IR/DerivedTypes.h"
@@ -33,7 +34,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/byte_order.h"
@@ -362,11 +362,10 @@ static void LogS64(const char* tag, int64 value) {
 void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b) {
   llvm::FunctionType* log_function_type = llvm::FunctionType::get(
       b->getVoidTy(), {b->getInt64Ty(), b->getInt64Ty()}, /*isVarArg=*/false);
-  b->CreateCall(
-      log_function_type,
-      b->CreateIntToPtr(b->getInt64(tensorflow::bit_cast<int64>(&LogS64)),
-                        log_function_type->getPointerTo()),
-      {b->getInt64(tensorflow::bit_cast<int64>(tag)), value});
+  b->CreateCall(log_function_type,
+                b->CreateIntToPtr(b->getInt64(absl::bit_cast<int64>(&LogS64)),
+                                  log_function_type->getPointerTo()),
+                {b->getInt64(absl::bit_cast<int64>(tag)), value});
 }
 
 void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 3ad6960b4ea..5c6183984ff 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -660,6 +660,7 @@ xla_test(
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/base",
         "@com_google_absl//absl/types:span",
     ],
 )
@@ -684,6 +685,7 @@ xla_test(
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/base",
     ],
 )
 
@@ -707,6 +709,7 @@ xla_test(
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/base",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -1616,6 +1619,7 @@ xla_test(
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/core:test",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/base",
     ],
 )
 
@@ -1860,6 +1864,7 @@ xla_test(
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
+        "@com_google_absl//absl/base",
         "@com_google_absl//absl/types:span",
     ],
 )
@@ -2152,6 +2157,7 @@ xla_test(
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
+        "@com_google_absl//absl/base",
         "@com_google_absl//absl/container:flat_hash_set",
     ],
 )
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index c257566fb21..c131bfd6a6e 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include <numeric>
 #include <vector>
 
+#include "absl/base/casts.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/array3d.h"
@@ -35,7 +36,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace xla {
@@ -139,7 +139,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) {
 }
 
 // A non-canonical quiet NaN value.
-static const float kNonCanonicalNaN = tensorflow::bit_cast<float>(0x7FD01234);
+static const float kNonCanonicalNaN = absl::bit_cast<float>(0x7FD01234);
 
 XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) {
   XlaBuilder builder(TestName());
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 5f063e67847..20bf3c31798 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include <vector>
 
 #include "absl/algorithm/container.h"
+#include "absl/base/casts.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -27,7 +28,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/lib/math/math_util.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/platform/test.h"
@@ -429,11 +429,9 @@ TEST_F(ConvertTest, ConvertReshape) {
 
 std::vector<float> GetInterestingF16ConversionTestCases() {
   float infinity = std::numeric_limits<float>::infinity();
-  float half_min_positive_normal =
-      tensorflow::bit_cast<float, uint32>(0x38800000);
-  float half_max_subnormal = tensorflow::bit_cast<float, uint32>(0x387fc000);
-  float half_min_positive_subnormal =
-      tensorflow::bit_cast<float, uint32>(0x33800000);
+  float half_min_positive_normal = absl::bit_cast<float, uint32>(0x38800000);
+  float half_max_subnormal = absl::bit_cast<float, uint32>(0x387fc000);
+  float half_min_positive_subnormal = absl::bit_cast<float, uint32>(0x33800000);
   float half_max = 65504.0f;
 
   std::vector<float> test_cases(
diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
index 51b50d456e4..c84973e17b2 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
@@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "absl/base/casts.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/core/lib/core/casts.h"
 
 namespace xla {
 namespace {
@@ -47,7 +47,7 @@ class ExhaustiveF32ElementwiseOpTest
         // input to 0 under the assumption that the op is at least correct on 0.
         input_literal.Set({i - begin}, 0.0f);
       } else {
-        input_literal.Set({i - begin}, tensorflow::bit_cast<float, int>(i));
+        input_literal.Set({i - begin}, absl::bit_cast<float, int>(i));
       }
     }
 
diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
index 193e6696925..f80d29b9de4 100644
--- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include <numeric>
 #include <vector>
 
+#include "absl/base/casts.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
@@ -34,7 +35,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace xla {
@@ -216,14 +216,13 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) {
   const uint32_t sign_bit = 1u << 31;
   for (const auto& test_value : test_values) {
     // Add positive values.
-    input_values.push_back(tensorflow::bit_cast<float>(test_value[0]));
-    expected_values.push_back(tensorflow::bit_cast<float>(test_value[index]));
+    input_values.push_back(absl::bit_cast<float>(test_value[0]));
+    expected_values.push_back(absl::bit_cast<float>(test_value[index]));
     // Add negative values.  We do this in the bitwise representation so as to
     // avoid problems with NaN handling.
-    input_values.push_back(
-        tensorflow::bit_cast<float>(test_value[0] ^ sign_bit));
+    input_values.push_back(absl::bit_cast<float>(test_value[0] ^ sign_bit));
     expected_values.push_back(
-        tensorflow::bit_cast<float>(test_value[index] ^ sign_bit));
+        absl::bit_cast<float>(test_value[index] ^ sign_bit));
   }
 
   // This is required for proper handling of NaN values.
diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
index 091a5d2cacc..606a099ecbc 100644
--- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include <memory>
 
+#include "absl/base/casts.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
@@ -27,7 +28,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/types.h"
@@ -47,7 +47,7 @@ class RoundTripPackedLiteralTest : public ClientLibraryTestBase {
 
 TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) {
   string data(sizeof(float) * 2, 0);
-  absl::Span<float> floats(tensorflow::bit_cast<float*>(data.data()), 2);
+  absl::Span<float> floats(absl::bit_cast<float*>(data.data()), 2);
   floats[0] = 42.0;
   floats[1] = 24.0;
 
@@ -69,7 +69,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) {
 
 TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
   string data(sizeof(float) * 4, 0);
-  absl::Span<float> floats(tensorflow::bit_cast<float*>(data.data()), 4);
+  absl::Span<float> floats(absl::bit_cast<float*>(data.data()), 4);
   // With x as the minor dimension, these will become:
   floats[0] = 42.0;  // y=0,x=0
   floats[1] = 24.0;  // y=0,x=1
@@ -102,7 +102,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
 
 TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
   string data(sizeof(float) * 4, 0);
-  absl::Span<float> floats(tensorflow::bit_cast<float*>(data.data()), 4);
+  absl::Span<float> floats(absl::bit_cast<float*>(data.data()), 4);
   // With y as the minor dimension, these will become:
   floats[0] = 42.0;  // y=0,x=0
   floats[1] = 24.0;  // y=1,x=0
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index bc433eac8fc..e066b3f4f22 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -15,13 +15,13 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/tests/test_utils.h"
 
+#include "absl/base/casts.h"
 #include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 
 namespace xla {
@@ -148,7 +148,7 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (
 
   absl::flat_hash_set<uint32> key_set;
   for (const float& value : key_arg.data<float>()) {
-    EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second);
+    EXPECT_TRUE(key_set.insert(absl::bit_cast<uint32>(value)).second);
   }
 }
 
@@ -171,7 +171,7 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (
 
   absl::flat_hash_set<int32> key_set;
   for (const int32& value : key_arg.data<int32>()) {
-    EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second);
+    EXPECT_TRUE(key_set.insert(absl::bit_cast<uint32>(value)).second);
   }
 }
 
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 1a2dd263b06..4628258efc4 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -393,7 +393,6 @@ cc_library(
         ":lib_platform",
         ":platform_base",
         "//tensorflow/core/platform/default/build_config:port",
-        "@com_google_absl//absl/base",
         "@snappy",
     ],
 )
@@ -657,7 +656,6 @@ cc_library(
         "lib/core/arena.h",
         "lib/core/bitmap.h",
         "lib/core/bits.h",
-        "lib/core/casts.h",
         "lib/core/coding.h",
         "lib/core/errors.h",
         "lib/core/notification.h",
@@ -2287,7 +2285,6 @@ cc_library(
     srcs = ["lib/png/png_io.cc"],
     hdrs = [
         "lib/bfloat16/bfloat16.h",
-        "lib/core/casts.h",
         "lib/core/stringpiece.h",
         "lib/png/png_io.h",
         "platform/byte_order.h",
@@ -2310,6 +2307,7 @@ cc_library(
         ":lib",
         ":lib_internal",
         "//tensorflow/core/platform/default/build_config:png",
+        "@com_google_absl//absl/base",
         "@com_google_absl//absl/strings",
         "@zlib_archive//:zlib",
     ],
@@ -2646,6 +2644,7 @@ tf_cuda_library(
         ":protos_all_cc",
         ":stats_calculator_portable",
         ":version_lib",
+        "@com_google_absl//absl/base",
         "//tensorflow/core/platform/default/build_config:platformlib",
         "//tensorflow/core/kernels:bounds_check",
         "//third_party/eigen3",
@@ -3328,7 +3327,6 @@ tf_cc_tests(
     size = "small",
     srcs = [
         "lib/core/arena_test.cc",
-        "lib/core/bit_cast_test.cc",
         "lib/core/bitmap_test.cc",
         "lib/core/blocking_counter_test.cc",
         "lib/core/coding_test.cc",
@@ -3534,6 +3532,7 @@ tf_cc_test(
         ":lib_internal",
         ":test",
         ":test_main",
+        "@com_google_absl//absl/base",
     ],
 )
 
@@ -3707,6 +3706,7 @@ tf_cc_tests(
         "//tensorflow/cc:while_loop",
         "//tensorflow/core/kernels:ops_util",
         "//third_party/eigen3",
+        "@com_google_absl//absl/base",
     ],
 )
 
diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc
index d71f92151d8..ce970854941 100644
--- a/tensorflow/core/framework/bfloat16_test.cc
+++ b/tensorflow/core/framework/bfloat16_test.cc
@@ -15,8 +15,8 @@ limitations under the License.
 
 #include "tensorflow/core/framework/bfloat16.h"
 
+#include "absl/base/casts.h"
 #include "tensorflow/core/framework/numeric_types.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
 
@@ -45,8 +45,8 @@ TEST(Bfloat16Test, Simple) {
 
 float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
                     uint32_t low_mantissa) {
-  return bit_cast<float>((sign << 31) + (exponent << 23) +
-                         (high_mantissa << 16) + low_mantissa);
+  return absl::bit_cast<float>((sign << 31) + (exponent << 23) +
+                               (high_mantissa << 16) + low_mantissa);
 }
 
 struct Bfloat16TestParam {
diff --git a/tensorflow/core/kernels/bitcast_op.cc b/tensorflow/core/kernels/bitcast_op.cc
index 90825e6d39a..f602cfa428a 100644
--- a/tensorflow/core/kernels/bitcast_op.cc
+++ b/tensorflow/core/kernels/bitcast_op.cc
@@ -19,7 +19,6 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/lib/core/casts.h"
 
 namespace tensorflow {
 
diff --git a/tensorflow/core/kernels/bitcast_op.h b/tensorflow/core/kernels/bitcast_op.h
index 900ab6f35c1..1f3659f3033 100644
--- a/tensorflow/core/kernels/bitcast_op.h
+++ b/tensorflow/core/kernels/bitcast_op.h
@@ -25,6 +25,5 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/lib/core/casts.h"
 
 #endif  // TENSORFLOW_CORE_KERNELS_BITCAST_OP_H_
diff --git a/tensorflow/core/kernels/hexagon/BUILD b/tensorflow/core/kernels/hexagon/BUILD
index 4870d9ae200..87d36f22d71 100644
--- a/tensorflow/core/kernels/hexagon/BUILD
+++ b/tensorflow/core/kernels/hexagon/BUILD
@@ -40,6 +40,7 @@ tf_cc_test(
         "//tensorflow/core/kernels:remote_fused_graph_ops",
         "//tensorflow/core/kernels:reshape_op",
         "//tensorflow/core/kernels:softmax_op",
+        "@com_google_absl//absl/base",
     ],
 )
 
diff --git a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
index d53977703e4..690d13c4e65 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
+++ b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
@@ -29,6 +29,7 @@ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
 
 #include <memory>
 
+#include "absl/base/casts.h"
 #include "tensorflow/core/framework/graph_transfer_info.pb.h"
 #include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
@@ -40,7 +41,6 @@ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
 #include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
 #include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
 #include "tensorflow/core/kernels/quantization_utils.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/io/path.h"
@@ -132,7 +132,7 @@ static void LoadImage(std::vector<float>* img_floats_ptr) {
   const int64 pixel_count = WIDTH * HEIGHT * DEPTH;
   CHECK(fsize >= 22 /* pos of height */ + sizeof(int));
   CHECK(bmp.data() != nullptr);
-  uint8* const img_bytes = bit_cast<uint8*>(bmp.data());
+  uint8* const img_bytes = absl::bit_cast<uint8*>(bmp.data());
   const int header_size = *(reinterpret_cast<int*>(img_bytes + 10));
   LOG(INFO) << "header size = " << header_size;
   const int size = *(reinterpret_cast<int*>(img_bytes + 14));
diff --git a/tensorflow/core/kernels/quantized_add_op.cc b/tensorflow/core/kernels/quantized_add_op.cc
index 337c8e5c178..55c69de7d3e 100644
--- a/tensorflow/core/kernels/quantized_add_op.cc
+++ b/tensorflow/core/kernels/quantized_add_op.cc
@@ -27,7 +27,6 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/kernels/meta_support.h"
 #include "tensorflow/core/kernels/quantization_utils.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/util/bcast.h"
 
diff --git a/tensorflow/core/kernels/quantized_mul_op.cc b/tensorflow/core/kernels/quantized_mul_op.cc
index 3c7536e0373..4e191f16266 100644
--- a/tensorflow/core/kernels/quantized_mul_op.cc
+++ b/tensorflow/core/kernels/quantized_mul_op.cc
@@ -26,7 +26,6 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/kernels/meta_support.h"
 #include "tensorflow/core/kernels/quantization_utils.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/util/bcast.h"
 
diff --git a/tensorflow/core/lib/core/bit_cast_test.cc b/tensorflow/core/lib/core/bit_cast_test.cc
deleted file mode 100644
index f68b2c40531..00000000000
--- a/tensorflow/core/lib/core/bit_cast_test.cc
+++ /dev/null
@@ -1,111 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Unit test for bit_cast template.
-
-#include "tensorflow/core/lib/core/casts.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-
-// Marshall and unmarshall.
-// ISO spec C++ section 3.9 promises this will work.
-
-template <int N>
-struct marshall {
-  char buf[N];
-};
-
-template <class T>
-void TestMarshall(const T values[], int num_values) {
-  for (int i = 0; i < num_values; ++i) {
-    T t0 = values[i];
-    marshall<sizeof(T)> m0 = bit_cast<marshall<sizeof(T)> >(t0);
-    T t1 = bit_cast<T>(m0);
-    marshall<sizeof(T)> m1 = bit_cast<marshall<sizeof(T)> >(t1);
-    ASSERT_EQ(0, memcmp(&t0, &t1, sizeof(T)));
-    ASSERT_EQ(0, memcmp(&m0, &m1, sizeof(T)));
-  }
-}
-
-// Convert back and forth to an integral type.  The C++ standard does
-// not guarantee this will work.
-//
-// There are implicit assumptions about sizeof(float) and
-// sizeof(double). These assumptions are quite extant everywhere.
-
-template <class T, class I>
-void TestIntegral(const T values[], int num_values) {
-  for (int i = 0; i < num_values; ++i) {
-    T t0 = values[i];
-    I i0 = bit_cast<I>(t0);
-    T t1 = bit_cast<T>(i0);
-    I i1 = bit_cast<I>(t1);
-    ASSERT_EQ(0, memcmp(&t0, &t1, sizeof(T)));
-    ASSERT_EQ(i0, i1);
-  }
-}
-
-TEST(BitCast, Bool) {
-  LOG(INFO) << "Test bool";
-  static const bool bool_list[] = {false, true};
-  TestMarshall<bool>(bool_list, TF_ARRAYSIZE(bool_list));
-}
-
-TEST(BitCast, Int32) {
-  static const int32 int_list[] = {0,  1,    100,         2147483647,
-                                   -1, -100, -2147483647, -2147483647 - 1};
-  TestMarshall<int32>(int_list, TF_ARRAYSIZE(int_list));
-}
-
-TEST(BitCast, Int64) {
-  static const int64 int64_list[] = {0, 1, 1LL << 40, -1, -(1LL << 40)};
-  TestMarshall<int64>(int64_list, TF_ARRAYSIZE(int64_list));
-}
-
-TEST(BitCast, Uint64) {
-  static const uint64 uint64_list[] = {0, 1, 1LLU << 40, 1LLU << 63};
-  TestMarshall<uint64>(uint64_list, TF_ARRAYSIZE(uint64_list));
-}
-
-TEST(BitCast, Float) {
-  static const float float_list[] = {0.0,  1.0,   -1.0,  10.0,    -10.0,  1e10,
-                                     1e20, 1e-10, 1e-20, 2.71828, 3.14159};
-  TestMarshall<float>(float_list, TF_ARRAYSIZE(float_list));
-  TestIntegral<float, int32>(float_list, TF_ARRAYSIZE(float_list));
-  TestIntegral<float, uint32>(float_list, TF_ARRAYSIZE(float_list));
-}
-
-TEST(BitCast, Double) {
-  static const double double_list[] = {
-      0.0,
-      1.0,
-      -1.0,
-      10.0,
-      -10.0,
-      1e10,
-      1e100,
-      1e-10,
-      1e-100,
-      2.718281828459045,
-      3.141592653589793238462643383279502884197169399375105820974944};
-  TestMarshall<double>(double_list, TF_ARRAYSIZE(double_list));
-  TestIntegral<double, int64>(double_list, TF_ARRAYSIZE(double_list));
-  TestIntegral<double, uint64>(double_list, TF_ARRAYSIZE(double_list));
-}
-
-}  // namespace tensorflow
diff --git a/tensorflow/core/lib/core/casts.h b/tensorflow/core/lib/core/casts.h
deleted file mode 100644
index 7546d4edc5a..00000000000
--- a/tensorflow/core/lib/core/casts.h
+++ /dev/null
@@ -1,100 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Various Google-specific casting templates.
-//
-// This code is compiled directly on many platforms, including client
-// platforms like Windows, Mac, and embedded systems.  Before making
-// any changes here, make sure that you're not breaking any platforms.
-//
-
-#ifndef TENSORFLOW_CORE_LIB_CORE_CASTS_H_
-#define TENSORFLOW_CORE_LIB_CORE_CASTS_H_
-
-#include <string.h>  // for memcpy
-
-namespace tensorflow {
-
-// bit_cast<Dest,Source> is a template function that implements the
-// equivalent of "*reinterpret_cast<Dest*>(&source)".  We need this in
-// very low-level functions like the protobuf library and fast math
-// support.
-//
-//   float f = 3.14159265358979;
-//   int i = bit_cast<int32>(f);
-//   // i = 0x40490fdb
-//
-// The classical address-casting method is:
-//
-//   // WRONG
-//   float f = 3.14159265358979;            // WRONG
-//   int i = * reinterpret_cast<int*>(&f);  // WRONG
-//
-// The address-casting method actually produces undefined behavior
-// according to ISO C++ specification section 3.10 -15 -.  Roughly, this
-// section says: if an object in memory has one type, and a program
-// accesses it with a different type, then the result is undefined
-// behavior for most values of "different type".
-//
-// This is true for any cast syntax, either *(int*)&f or
-// *reinterpret_cast<int*>(&f).  And it is particularly true for
-// conversions between integral lvalues and floating-point lvalues.
-//
-// The purpose of 3.10 -15- is to allow optimizing compilers to assume
-// that expressions with different types refer to different memory.  gcc
-// 4.0.1 has an optimizer that takes advantage of this.  So a
-// non-conforming program quietly produces wildly incorrect output.
-//
-// The problem is not the use of reinterpret_cast.  The problem is type
-// punning: holding an object in memory of one type and reading its bits
-// back using a different type.
-//
-// The C++ standard is more subtle and complex than this, but that
-// is the basic idea.
-//
-// Anyways ...
-//
-// bit_cast<> calls memcpy() which is blessed by the standard,
-// especially by the example in section 3.9 .  Also, of course,
-// bit_cast<> wraps up the nasty logic in one place.
-//
-// Fortunately memcpy() is very fast.  In optimized mode, with a
-// constant size, gcc 2.95.3, gcc 4.0.1, and msvc 7.1 produce inline
-// code with the minimal amount of data movement.  On a 32-bit system,
-// memcpy(d,s,4) compiles to one load and one store, and memcpy(d,s,8)
-// compiles to two loads and two stores.
-//
-// I tested this code with gcc 2.95.3, gcc 4.0.1, icc 8.1, and msvc 7.1.
-//
-// WARNING: if Dest or Source is a non-POD type, the result of the memcpy
-// is likely to surprise you.
-//
-// Props to Bill Gibbons for the compile time assertion technique and
-// Art Komninos and Igor Tandetnik for the msvc experiments.
-//
-// -- mec 2005-10-17
-
-template <class Dest, class Source>
-inline Dest bit_cast(const Source& source) {
-  static_assert(sizeof(Dest) == sizeof(Source), "Sizes do not match");
-
-  Dest dest;
-  memcpy(&dest, &source, sizeof(dest));
-  return dest;
-}
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_CORE_LIB_CORE_CASTS_H_
diff --git a/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc b/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc
index 15266af1dbd..62dd31a65f6 100644
--- a/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc
+++ b/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc
@@ -22,18 +22,19 @@ limitations under the License.
 
 #include <memory>
 
+#include "absl/base/casts.h"
 #include "tensorflow/core/lib/jpeg/jpeg_handle.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/types.h"
 
-#include "tensorflow/core/lib/core/casts.h"
 
 namespace tensorflow {
 namespace jpeg {
 namespace {
 
+using absl::bit_cast;
 const char kTestData[] = "tensorflow/core/lib/jpeg/testdata/";
 
 int ComputeSumAbsoluteDifference(const uint8* a, const uint8* b, int width,
diff --git a/tensorflow/core/lib/png/png_io.cc b/tensorflow/core/lib/png/png_io.cc
index e226a15ccca..bc52180265c 100644
--- a/tensorflow/core/lib/png/png_io.cc
+++ b/tensorflow/core/lib/png/png_io.cc
@@ -24,7 +24,7 @@ limitations under the License.
 // NOTE(skal): we don't '#include <setjmp.h>' before png.h as it otherwise
 // provokes a compile error. We instead let png.h include what is needed.
 
-#include "tensorflow/core/lib/core/casts.h"
+#include "absl/base/casts.h"
 #include "tensorflow/core/lib/png/png_io.h"
 #include "tensorflow/core/platform/byte_order.h"
 #include "tensorflow/core/platform/logging.h"
@@ -76,7 +76,8 @@ static void Convert8to16(const uint8* p8, int num_comps, int p8_row_bytes,
 #undef CPTR_INC
 
 void ErrorHandler(png_structp png_ptr, png_const_charp msg) {
-  DecodeContext* const ctx = bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
+  DecodeContext* const ctx =
+      absl::bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
   ctx->error_condition = true;
   // To prevent log spam, errors are logged as VLOG(1) instead of ERROR.
   VLOG(1) << "PNG error: " << msg;
@@ -88,7 +89,8 @@ void WarningHandler(png_structp png_ptr, png_const_charp msg) {
 }
 
 void StringReader(png_structp png_ptr, png_bytep data, png_size_t length) {
-  DecodeContext* const ctx = bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
+  DecodeContext* const ctx =
+      absl::bit_cast<DecodeContext*>(png_get_io_ptr(png_ptr));
   if (static_cast<png_size_t>(ctx->data_left) < length) {
     memset(data, 0, length);
     png_error(png_ptr, "More bytes requested to read than available");
@@ -100,8 +102,8 @@ void StringReader(png_structp png_ptr, png_bytep data, png_size_t length) {
 }
 
 void StringWriter(png_structp png_ptr, png_bytep data, png_size_t length) {
-  string* const s = bit_cast<string*>(png_get_io_ptr(png_ptr));
-  s->append(bit_cast<const char*>(data), length);
+  string* const s = absl::bit_cast<string*>(png_get_io_ptr(png_ptr));
+  s->append(absl::bit_cast<const char*>(data), length);
 }
 
 void StringWriterFlush(png_structp png_ptr) {}
@@ -215,7 +217,7 @@ bool CommonInitDecode(StringPiece png_string, int desired_channels,
     CommonFreeDecode(context);
     return false;
   }
-  context->data = bit_cast<const uint8*>(png_string.data());
+  context->data = absl::bit_cast<const uint8*>(png_string.data());
   context->data_left = png_string.size();
   png_set_read_fn(context->png_ptr, context, StringReader);
   png_read_info(context->png_ptr, context->info_ptr);
@@ -328,8 +330,8 @@ bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context) {
 
   // Synthesize 16 bits from 8 if requested.
   if (context->need_to_synthesize_16)
-    Convert8to16(bit_cast<uint8*>(data), context->channels, row_bytes,
-                 context->width, context->height, bit_cast<uint16*>(data),
+    Convert8to16(absl::bit_cast<uint8*>(data), context->channels, row_bytes,
+                 context->width, context->height, absl::bit_cast<uint16*>(data),
                  row_bytes);
   return ok;
 }
diff --git a/tensorflow/core/lib/png/png_io.h b/tensorflow/core/lib/png/png_io.h
index c876c5156ab..d3a44b19eed 100644
--- a/tensorflow/core/lib/png/png_io.h
+++ b/tensorflow/core/lib/png/png_io.h
@@ -35,6 +35,7 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
+#include "absl/base/casts.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/platform/png.h"
 #include "tensorflow/core/platform/types.h"
@@ -68,7 +69,7 @@ bool DecodeHeader(StringPiece png_string, int* width, int* height,
 // DecodeContext context;
 // CHECK(CommonInitDecode(png_string, 3 /*RGB*/, 8 /*uint8*/, &context));
 // char* image_buffer = new char[3*context.width*context.height];
-// CHECK(CommonFinishDecode(bit_cast<png_byte*>(image_buffer),
+// CHECK(CommonFinishDecode(absl::bit_cast<png_byte*>(image_buffer),
 //       3*context.width /*stride*/, &context));
 //
 // desired_channels may be 0 to detected it from the input.
diff --git a/tensorflow/core/lib/wav/wav_io.cc b/tensorflow/core/lib/wav/wav_io.cc
index c536b5688ef..b4f0bfbfb96 100644
--- a/tensorflow/core/lib/wav/wav_io.cc
+++ b/tensorflow/core/lib/wav/wav_io.cc
@@ -19,7 +19,7 @@ limitations under the License.
 #include <string.h>
 #include <algorithm>
 
-#include "tensorflow/core/lib/core/casts.h"
+#include "absl/base/casts.h"
 #include "tensorflow/core/lib/core/coding.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/wav/wav_io.h"
@@ -174,7 +174,7 @@ Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
 
   wav_string->resize(file_size);
   char* data = &wav_string->at(0);
-  WavHeader* header = bit_cast<WavHeader*>(data);
+  WavHeader* header = absl::bit_cast<WavHeader*>(data);
 
   // Fill RIFF chunk.
   auto* riff_chunk = &header->riff_chunk;
diff --git a/tensorflow/core/lib/wav/wav_io.h b/tensorflow/core/lib/wav/wav_io.h
index f004524177e..9145e7c9f22 100644
--- a/tensorflow/core/lib/wav/wav_io.h
+++ b/tensorflow/core/lib/wav/wav_io.h
@@ -21,7 +21,6 @@ limitations under the License.
 #include <string>
 #include <vector>
 
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/lib/core/coding.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc
index e52d55e2ffe..3cc75bbd1f3 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing.cc
@@ -16,13 +16,13 @@ limitations under the License.
 
 #include <vector>
 
+#include "absl/base/casts.h"
 #include "tensorflow/core/example/example.pb.h"
 #include "tensorflow/core/example/feature.pb_text.h"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
-#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -164,7 +164,7 @@ class Feature {
         while (!stream.ExpectAtEnd()) {
           uint32 buffer32;
           if (!stream.ReadLittleEndian32(&buffer32)) return false;
-          float_list->push_back(bit_cast<float>(buffer32));
+          float_list->push_back(absl::bit_cast<float>(buffer32));
         }
 
         stream.PopLimit(packed_limit);
@@ -173,7 +173,7 @@ class Feature {
           if (!stream.ExpectTag(kFixed32Tag(1))) return false;
           uint32 buffer32;
           if (!stream.ReadLittleEndian32(&buffer32)) return false;
-          float_list->push_back(bit_cast<float>(buffer32));
+          float_list->push_back(absl::bit_cast<float>(buffer32));
         }
       }
     }
@@ -1600,7 +1600,7 @@ inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
           return -1;
         }
         if (out != nullptr) {
-          *out++ = bit_cast<float>(buffer32);
+          *out++ = absl::bit_cast<float>(buffer32);
         }
         num_elements++;
       }
@@ -1613,7 +1613,7 @@ inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
           return -1;
         }
         if (out != nullptr) {
-          *out++ = bit_cast<float>(buffer32);
+          *out++ = absl::bit_cast<float>(buffer32);
         }
         num_elements++;
       }