diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 1c9bddd1dbc..b32acbedcf1 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1974,6 +1974,7 @@ cc_library( "//tensorflow/core/platform:abi", "//tensorflow/core/platform:base64", "//tensorflow/core/platform:blocking_counter", + "//tensorflow/core/platform:casts", "//tensorflow/core/platform:coding", "//tensorflow/core/platform:context", "//tensorflow/core/platform:cord", diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 2c1dceb8f4e..fbf681ac329 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/refcount.h" diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h index dd80ead1f86..ad5d1517176 100644 --- a/tensorflow/core/kernels/data/iterator_ops.h +++ b/tensorflow/core/kernels/data/iterator_ops.h @@ -124,20 +124,6 @@ class IteratorHandleOp : public OpKernel { // inconsistent capacities. Status VerifyResource(IteratorResource* resource); - template <typename To, typename From> // use like this: down_cast<T*>(foo); - static inline To down_cast(From* f) { // so we only accept pointers - static_assert( - (std::is_base_of<From, typename std::remove_pointer<To>::type>::value), - "target type not derived from source type"); - - // We skip the assert and hence the dynamic_cast if RTTI is disabled. -#if !defined(__GNUC__) || defined(__GXX_RTTI) - // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds. - assert(f == nullptr || dynamic_cast<To>(f) != nullptr); -#endif // !defined(__GNUC__) || defined(__GXX_RTTI) - return static_cast<To>(f); - } - FunctionLibraryRuntime* CreatePrivateFLR( OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr, std::unique_ptr<FunctionLibraryDefinition>* flib_def, diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc index d7d15d5f14b..ec749dfe9dd 100644 --- a/tensorflow/core/kernels/functional_ops.cc +++ b/tensorflow/core/kernels/functional_ops.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/profiler/lib/traceme.h" namespace tensorflow { @@ -41,21 +42,6 @@ Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func, return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle); } -template <typename To, typename From> // use like this: down_cast<T*>(foo); -inline To down_cast(From* f) { // so we only accept pointers - static_assert( - (std::is_base_of<From, typename std::remove_pointer<To>::type>::value), - "target type not derived from source type"); - - // We skip the assert and hence the dynamic_cast if RTTI is disabled. -#if !defined(__GNUC__) || defined(__GXX_RTTI) - // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds. - assert(f == nullptr || dynamic_cast<To>(f) != nullptr); -#endif // !defined(__GNUC__) || defined(__GXX_RTTI) - - return static_cast<To>(f); -} - // If "t" is a scalar of a supported type, returns t != 0 in "*v". Status ToBool(gtl::ArraySlice<Tensor> t, bool* v) { if (t.size() != 1) { diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index e44cfdf1ec7..80ca00388ff 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -73,6 +73,7 @@ limitations under the License. #include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" @@ -104,7 +105,7 @@ Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) { } else if (ctx->op_device_context() != nullptr) { // TODO(apassos): remove the down_cast by just returning Device* from // OpKernelContext - Device* device = static_cast<Device*>(ctx->device()); + Device* device = down_cast<Device*>(ctx->device()); ctx->op_device_context()->CopyTensorInSameDevice( t, device, output, [&n, &status](const Status& s) { status = s; diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index bcfb935206e..5dfeeb89c43 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -158,6 +158,14 @@ cc_library( ], ) +cc_library( + name = "casts", + hdrs = ["casts.h"], + deps = [ + ":platform", + ] + tf_platform_deps("casts"), +) + cc_library( name = "cuda", hdrs = ["cuda.h"], @@ -1060,6 +1068,7 @@ filegroup( name = "lib_hdrs", srcs = [ "abi.h", + "casts.h", "context.h", "cpu_feature_guard.h", "cpu_info.h", @@ -1254,6 +1263,7 @@ filegroup( "//tensorflow/core/platform:base64.h", "//tensorflow/core/platform:blocking_counter.h", "//tensorflow/core/platform:byte_order.h", + "//tensorflow/core/platform:casts.h", "//tensorflow/core/platform:coding.cc", "//tensorflow/core/platform:coding.h", "//tensorflow/core/platform:context.h", @@ -1360,6 +1370,7 @@ filegroup( name = "legacy_srcs_no_runtime", srcs = [ ":legacy_srcs_common", + "//tensorflow/core/platform/default:casts.h", "//tensorflow/core/platform/default:context.h", "//tensorflow/core/platform/default:cord.h", "//tensorflow/core/platform/default:dynamic_annotations.h", diff --git a/tensorflow/core/platform/casts.h b/tensorflow/core/platform/casts.h new file mode 100644 index 00000000000..be7be00bd45 --- /dev/null +++ b/tensorflow/core/platform/casts.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CASTS_H_ +#define TENSORFLOW_CORE_PLATFORM_CASTS_H_ + +#include "tensorflow/core/platform/platform.h" + +#if defined(PLATFORM_GOOGLE) +#include "tensorflow/core/platform/google/casts.h" +#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \ + defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \ + defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS) +#include "tensorflow/core/platform/default/casts.h" +#else +#error Define the appropriate PLATFORM_<foo> macro for this platform +#endif + +#endif // TENSORFLOW_CORE_PLATFORM_CASTS_H_ diff --git a/tensorflow/core/platform/default/BUILD b/tensorflow/core/platform/default/BUILD index 491f84536cf..346018153d5 100644 --- a/tensorflow/core/platform/default/BUILD +++ b/tensorflow/core/platform/default/BUILD @@ -9,6 +9,16 @@ package( licenses = ["notice"], # Apache 2.0 ) +cc_library( + name = "casts", + hdrs = ["casts.h"], + tags = [ + "manual", + "no_oss", + "nobuilder", + ], +) + cc_library( name = "context", hdrs = ["//tensorflow/core/platform:context.h"], diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 61fe01cb262..3c0a4676eff 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -541,6 +541,7 @@ def tf_proto_library( def tf_additional_lib_hdrs(): return [ + "//tensorflow/core/platform/default:casts.h", "//tensorflow/core/platform/default:context.h", "//tensorflow/core/platform/default:cord.h", "//tensorflow/core/platform/default:dynamic_annotations.h", diff --git a/tensorflow/core/platform/default/casts.h b/tensorflow/core/platform/default/casts.h new file mode 100644 index 00000000000..ed1d2a66812 --- /dev/null +++ b/tensorflow/core/platform/default/casts.h @@ -0,0 +1,92 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CASTS_H_ +#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CASTS_H_ + +#include <assert.h> // for use with down_cast<> + +#include <type_traits> + +namespace tensorflow { + +// An "upcast", i.e. a conversion from a pointer to an object to a pointer to a +// base subobject, always succeeds if the base is unambiguous and accessible, +// and so it's fine to use implicit_cast. +// +// A "downcast", i.e. a conversion from a pointer to an object to a pointer +// to a more-derived object that may contain the original object as a base +// subobject, cannot safely be done using static_cast, because you do not +// generally know whether the source object is really the base subobject of +// a containing, more-derived object of the target type. Thus, when you +// downcast in a polymorphic type hierarchy, you should use the following +// function template. +// +// In debug mode, we use dynamic_cast to double-check whether the downcast is +// legal (we die if it's not). In normal mode, we do the efficient static_cast +// instead. Thus, it's important to test in debug mode to make sure the cast is +// legal! +// +// This is the only place in the codebase we should use dynamic_cast. +// In particular, you should NOT use dynamic_cast for RTTI, e.g. for +// code like this: +// if (auto* p = dynamic_cast<Subclass1*>(foo)) HandleASubclass1Object(p); +// if (auto* p = dynamic_cast<Subclass2*>(foo)) HandleASubclass2Object(p); +// You should design the code some other way not to need this. + +template <typename To, typename From> // use like this: down_cast<T*>(foo); +inline To down_cast(From* f) { // so we only accept pointers + static_assert( + (std::is_base_of<From, typename std::remove_pointer<To>::type>::value), + "target type not derived from source type"); + + // We skip the assert and hence the dynamic_cast if RTTI is disabled. +#if !defined(__GNUC__) || defined(__GXX_RTTI) + // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds. + assert(f == nullptr || dynamic_cast<To>(f) != nullptr); +#endif // !defined(__GNUC__) || defined(__GXX_RTTI) + + return static_cast<To>(f); +} + +// Overload of down_cast for references. Use like this: down_cast<T&>(foo). +// The code is slightly convoluted because we're still using the pointer +// form of dynamic cast. (The reference form throws an exception if it +// fails.) +// +// There's no need for a special const overload either for the pointer +// or the reference form. If you call down_cast with a const T&, the +// compiler will just bind From to const T. +template <typename To, typename From> +inline To down_cast(From& f) { + static_assert(std::is_lvalue_reference<To>::value, + "target type not a reference"); + static_assert( + (std::is_base_of<From, typename std::remove_reference<To>::type>::value), + "target type not derived from source type"); + + // We skip the assert and hence the dynamic_cast if RTTI is disabled. +#if !defined(__GNUC__) || defined(__GXX_RTTI) + // RTTI: debug mode only + assert(dynamic_cast<typename std::remove_reference<To>::type*>(&f) != + nullptr); +#endif // !defined(__GNUC__) || defined(__GXX_RTTI) + + return static_cast<To>(f); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_CASTS_H_