Move down_cast to platform/casts.h
PiperOrigin-RevId: 289736444 Change-Id: I93321c9130243b15d789bd4ec63588d54adc011e
This commit is contained in:
parent
99eb226655
commit
c8dc5d6d53
@ -1974,6 +1974,7 @@ cc_library(
|
||||
"//tensorflow/core/platform:abi",
|
||||
"//tensorflow/core/platform:base64",
|
||||
"//tensorflow/core/platform:blocking_counter",
|
||||
"//tensorflow/core/platform:casts",
|
||||
"//tensorflow/core/platform:coding",
|
||||
"//tensorflow/core/platform:context",
|
||||
"//tensorflow/core/platform:cord",
|
||||
|
@ -43,6 +43,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/refcount.h"
|
||||
|
@ -124,20 +124,6 @@ class IteratorHandleOp : public OpKernel {
|
||||
// inconsistent capacities.
|
||||
Status VerifyResource(IteratorResource* resource);
|
||||
|
||||
template <typename To, typename From> // use like this: down_cast<T*>(foo);
|
||||
static inline To down_cast(From* f) { // so we only accept pointers
|
||||
static_assert(
|
||||
(std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
|
||||
"target type not derived from source type");
|
||||
|
||||
// We skip the assert and hence the dynamic_cast if RTTI is disabled.
|
||||
#if !defined(__GNUC__) || defined(__GXX_RTTI)
|
||||
// Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
|
||||
assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
|
||||
#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
|
||||
return static_cast<To>(f);
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime* CreatePrivateFLR(
|
||||
OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
|
||||
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -41,21 +42,6 @@ Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func,
|
||||
return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle);
|
||||
}
|
||||
|
||||
template <typename To, typename From> // use like this: down_cast<T*>(foo);
|
||||
inline To down_cast(From* f) { // so we only accept pointers
|
||||
static_assert(
|
||||
(std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
|
||||
"target type not derived from source type");
|
||||
|
||||
// We skip the assert and hence the dynamic_cast if RTTI is disabled.
|
||||
#if !defined(__GNUC__) || defined(__GXX_RTTI)
|
||||
// Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
|
||||
assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
|
||||
#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
|
||||
|
||||
return static_cast<To>(f);
|
||||
}
|
||||
|
||||
// If "t" is a scalar of a supported type, returns t != 0 in "*v".
|
||||
Status ToBool(gtl::ArraySlice<Tensor> t, bool* v) {
|
||||
if (t.size() != 1) {
|
||||
|
@ -73,6 +73,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/variable_ops.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/platform/casts.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -104,7 +105,7 @@ Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) {
|
||||
} else if (ctx->op_device_context() != nullptr) {
|
||||
// TODO(apassos): remove the down_cast by just returning Device* from
|
||||
// OpKernelContext
|
||||
Device* device = static_cast<Device*>(ctx->device());
|
||||
Device* device = down_cast<Device*>(ctx->device());
|
||||
ctx->op_device_context()->CopyTensorInSameDevice(
|
||||
t, device, output, [&n, &status](const Status& s) {
|
||||
status = s;
|
||||
|
@ -158,6 +158,14 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "casts",
|
||||
hdrs = ["casts.h"],
|
||||
deps = [
|
||||
":platform",
|
||||
] + tf_platform_deps("casts"),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cuda",
|
||||
hdrs = ["cuda.h"],
|
||||
@ -1060,6 +1068,7 @@ filegroup(
|
||||
name = "lib_hdrs",
|
||||
srcs = [
|
||||
"abi.h",
|
||||
"casts.h",
|
||||
"context.h",
|
||||
"cpu_feature_guard.h",
|
||||
"cpu_info.h",
|
||||
@ -1254,6 +1263,7 @@ filegroup(
|
||||
"//tensorflow/core/platform:base64.h",
|
||||
"//tensorflow/core/platform:blocking_counter.h",
|
||||
"//tensorflow/core/platform:byte_order.h",
|
||||
"//tensorflow/core/platform:casts.h",
|
||||
"//tensorflow/core/platform:coding.cc",
|
||||
"//tensorflow/core/platform:coding.h",
|
||||
"//tensorflow/core/platform:context.h",
|
||||
@ -1360,6 +1370,7 @@ filegroup(
|
||||
name = "legacy_srcs_no_runtime",
|
||||
srcs = [
|
||||
":legacy_srcs_common",
|
||||
"//tensorflow/core/platform/default:casts.h",
|
||||
"//tensorflow/core/platform/default:context.h",
|
||||
"//tensorflow/core/platform/default:cord.h",
|
||||
"//tensorflow/core/platform/default:dynamic_annotations.h",
|
||||
|
31
tensorflow/core/platform/casts.h
Normal file
31
tensorflow/core/platform/casts.h
Normal file
@ -0,0 +1,31 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_CASTS_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_CASTS_H_
|
||||
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
#include "tensorflow/core/platform/google/casts.h"
|
||||
#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \
|
||||
defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \
|
||||
defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS)
|
||||
#include "tensorflow/core/platform/default/casts.h"
|
||||
#else
|
||||
#error Define the appropriate PLATFORM_<foo> macro for this platform
|
||||
#endif
|
||||
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_CASTS_H_
|
@ -9,6 +9,16 @@ package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "casts",
|
||||
hdrs = ["casts.h"],
|
||||
tags = [
|
||||
"manual",
|
||||
"no_oss",
|
||||
"nobuilder",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "context",
|
||||
hdrs = ["//tensorflow/core/platform:context.h"],
|
||||
|
@ -541,6 +541,7 @@ def tf_proto_library(
|
||||
|
||||
def tf_additional_lib_hdrs():
|
||||
return [
|
||||
"//tensorflow/core/platform/default:casts.h",
|
||||
"//tensorflow/core/platform/default:context.h",
|
||||
"//tensorflow/core/platform/default:cord.h",
|
||||
"//tensorflow/core/platform/default:dynamic_annotations.h",
|
||||
|
92
tensorflow/core/platform/default/casts.h
Normal file
92
tensorflow/core/platform/default/casts.h
Normal file
@ -0,0 +1,92 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CASTS_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CASTS_H_
|
||||
|
||||
#include <assert.h> // for use with down_cast<>
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// An "upcast", i.e. a conversion from a pointer to an object to a pointer to a
|
||||
// base subobject, always succeeds if the base is unambiguous and accessible,
|
||||
// and so it's fine to use implicit_cast.
|
||||
//
|
||||
// A "downcast", i.e. a conversion from a pointer to an object to a pointer
|
||||
// to a more-derived object that may contain the original object as a base
|
||||
// subobject, cannot safely be done using static_cast, because you do not
|
||||
// generally know whether the source object is really the base subobject of
|
||||
// a containing, more-derived object of the target type. Thus, when you
|
||||
// downcast in a polymorphic type hierarchy, you should use the following
|
||||
// function template.
|
||||
//
|
||||
// In debug mode, we use dynamic_cast to double-check whether the downcast is
|
||||
// legal (we die if it's not). In normal mode, we do the efficient static_cast
|
||||
// instead. Thus, it's important to test in debug mode to make sure the cast is
|
||||
// legal!
|
||||
//
|
||||
// This is the only place in the codebase we should use dynamic_cast.
|
||||
// In particular, you should NOT use dynamic_cast for RTTI, e.g. for
|
||||
// code like this:
|
||||
// if (auto* p = dynamic_cast<Subclass1*>(foo)) HandleASubclass1Object(p);
|
||||
// if (auto* p = dynamic_cast<Subclass2*>(foo)) HandleASubclass2Object(p);
|
||||
// You should design the code some other way not to need this.
|
||||
|
||||
template <typename To, typename From> // use like this: down_cast<T*>(foo);
|
||||
inline To down_cast(From* f) { // so we only accept pointers
|
||||
static_assert(
|
||||
(std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
|
||||
"target type not derived from source type");
|
||||
|
||||
// We skip the assert and hence the dynamic_cast if RTTI is disabled.
|
||||
#if !defined(__GNUC__) || defined(__GXX_RTTI)
|
||||
// Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
|
||||
assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
|
||||
#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
|
||||
|
||||
return static_cast<To>(f);
|
||||
}
|
||||
|
||||
// Overload of down_cast for references. Use like this: down_cast<T&>(foo).
|
||||
// The code is slightly convoluted because we're still using the pointer
|
||||
// form of dynamic cast. (The reference form throws an exception if it
|
||||
// fails.)
|
||||
//
|
||||
// There's no need for a special const overload either for the pointer
|
||||
// or the reference form. If you call down_cast with a const T&, the
|
||||
// compiler will just bind From to const T.
|
||||
template <typename To, typename From>
|
||||
inline To down_cast(From& f) {
|
||||
static_assert(std::is_lvalue_reference<To>::value,
|
||||
"target type not a reference");
|
||||
static_assert(
|
||||
(std::is_base_of<From, typename std::remove_reference<To>::type>::value),
|
||||
"target type not derived from source type");
|
||||
|
||||
// We skip the assert and hence the dynamic_cast if RTTI is disabled.
|
||||
#if !defined(__GNUC__) || defined(__GXX_RTTI)
|
||||
// RTTI: debug mode only
|
||||
assert(dynamic_cast<typename std::remove_reference<To>::type*>(&f) !=
|
||||
nullptr);
|
||||
#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
|
||||
|
||||
return static_cast<To>(f);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_CASTS_H_
|
Loading…
Reference in New Issue
Block a user