Rolling back PR #44894 because it caused type-check errors (got bfloat16 when expecting float).

PiperOrigin-RevId: 346214899
Change-Id: I291592f14a7e3b087e1c26d29d7ecdaef4bc2fed
parent 4764f167c2
commit bb59188213
tensorflow/core/kernels/sparse_xent_op.cc
@@ -18,7 +18,6 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 
 #include "tensorflow/core/kernels/sparse_xent_op.h"
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -123,8 +122,6 @@ REGISTER(CPU, float, int32)
 REGISTER(CPU, float, int64)
 REGISTER(CPU, double, int32)
 REGISTER(CPU, double, int64)
-REGISTER(CPU, bfloat16, int32)
-REGISTER(CPU, bfloat16, int64)
 REGISTER(CPU, Eigen::half, int32)
 REGISTER(CPU, Eigen::half, int64)
 
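For context: the REGISTER invocations above go through a kernel-registration macro defined earlier in sparse_xent_op.cc, outside this hunk. A sketch of its likely shape (approximate, not the verbatim TensorFlow source) shows why deleting the two bfloat16 lines is enough to unregister those CPU kernels:

    // Sketch of the registration macro invoked above (approximate, not the
    // verbatim TensorFlow source). Each invocation registers one
    // (device, logits dtype, label-index dtype) kernel variant.
    #define REGISTER(Dev, T, Index)                     \
      REGISTER_KERNEL_BUILDER(                          \
          Name("SparseSoftmaxCrossEntropyWithLogits")   \
              .Device(DEVICE_##Dev)                     \
              .TypeConstraint<T>("T")                   \
              .TypeConstraint<Index>("Tlabels"),        \
          SparseSoftmaxXentWithLogitsOp<Dev##Device, T, Index>);

With the bfloat16 registrations gone, bfloat16 logits simply fail CPU kernel lookup, rather than producing bfloat16 outputs where downstream type checks expected float.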
tensorflow/core/kernels/sparse_xent_op_test.cc
@@ -23,9 +23,9 @@ limitations under the License.
 
 namespace tensorflow {
 
-static Graph* SparseXent(int batch_size, int num_classes, DataType value_type) {
+static Graph* SparseXent(int batch_size, int num_classes) {
   Graph* g = new Graph(OpRegistry::Global());
-  Tensor logits(value_type, TensorShape({batch_size, num_classes}));
+  Tensor logits(DT_FLOAT, TensorShape({batch_size, num_classes}));
   logits.flat<float>().setRandom();
   Tensor labels(DT_INT64, TensorShape({batch_size}));
   std::random_device rd;
@@ -41,45 +41,44 @@ static Graph* SparseXent(int batch_size, int num_classes, DataType value_type) {
   return g;
 }
 
-#define BM_SparseXentDev(BATCH, CLASS, DEVICE, DTYPE)                  \
-  static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE##_##DTYPE(  \
+#define BM_SparseXentDev(BATCH, CLASS, DEVICE)                         \
+  static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE(            \
       ::testing::benchmark::State& state) {                            \
-    test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS, DTYPE),          \
+    test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS),                 \
                     /*old_benchmark_api*/ false)                       \
         .Run(state);                                                   \
     state.SetItemsProcessed(static_cast<int64>(state.iterations()) * BATCH * \
                             CLASS);                                    \
   }                                                                    \
-  BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE##_##DTYPE);
-
-#define BM_SPARSE_XENT_DEV_CPU(DTYPE)       \
-  BM_SparseXentDev(8, 1000000, cpu, DTYPE); \
-  BM_SparseXentDev(16, 10000, cpu, DTYPE);  \
-  BM_SparseXentDev(16, 100000, cpu, DTYPE); \
-  BM_SparseXentDev(32, 10000, cpu, DTYPE);  \
-  BM_SparseXentDev(32, 100000, cpu, DTYPE); \
-  BM_SparseXentDev(64, 10000, cpu, DTYPE);  \
-  BM_SparseXentDev(64, 100000, cpu, DTYPE);
-
-// CPU
-BM_SPARSE_XENT_DEV_CPU(DT_FLOAT);
-BM_SPARSE_XENT_DEV_CPU(DT_BFLOAT16);
+  BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE);
 
 /// The representative tests for ptb_word on GPU
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-BM_SparseXentDev(8, 1000000, gpu, DT_FLOAT);
+BM_SparseXentDev(8, 1000000, gpu);
 
-BM_SparseXentDev(16, 10000, gpu, DT_FLOAT);
-BM_SparseXentDev(16, 30000, gpu, DT_FLOAT);
-BM_SparseXentDev(16, 100000, gpu, DT_FLOAT);
+BM_SparseXentDev(16, 10000, gpu);
+BM_SparseXentDev(16, 30000, gpu);
+BM_SparseXentDev(16, 100000, gpu);
 
-BM_SparseXentDev(32, 10000, gpu, DT_FLOAT);
-BM_SparseXentDev(32, 30000, gpu, DT_FLOAT);
-BM_SparseXentDev(32, 100000, gpu, DT_FLOAT);
+BM_SparseXentDev(32, 10000, gpu);
+BM_SparseXentDev(32, 30000, gpu);
+BM_SparseXentDev(32, 100000, gpu);
 
-BM_SparseXentDev(64, 10000, gpu, DT_FLOAT);
-BM_SparseXentDev(64, 30000, gpu, DT_FLOAT);
-BM_SparseXentDev(64, 100000, gpu, DT_FLOAT);
+BM_SparseXentDev(64, 10000, gpu);
+BM_SparseXentDev(64, 30000, gpu);
+BM_SparseXentDev(64, 100000, gpu);
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
+// CPU
+BM_SparseXentDev(8, 1000000, cpu);
+
+BM_SparseXentDev(16, 10000, cpu);
+BM_SparseXentDev(16, 100000, cpu);
+
+BM_SparseXentDev(32, 10000, cpu);
+BM_SparseXentDev(32, 100000, cpu);
+
+BM_SparseXentDev(64, 10000, cpu);
+BM_SparseXentDev(64, 100000, cpu);
+
 }  // end namespace tensorflow
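As a readability check on the reverted macro, here is approximately what one post-rollback instantiation, BM_SparseXentDev(8, 1000000, cpu), expands to; the ## token pasting builds the benchmark name and #DEVICE stringizes the device argument:

    // Approximate expansion of BM_SparseXentDev(8, 1000000, cpu) after this
    // change; the name below is what the ## pasting produces.
    static void BM_SparseXent_8_1000000_cpu(
        ::testing::benchmark::State& state) {
      test::Benchmark("cpu", SparseXent(8, 1000000),
                      /*old_benchmark_api*/ false)
          .Run(state);
      state.SetItemsProcessed(static_cast<int64>(state.iterations()) * 8 *
                              1000000);
    }
    BENCHMARK(BM_SparseXent_8_1000000_cpu);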
tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -182,23 +182,6 @@ class SparseXentTest(test.TestCase):
         np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64),
         np.array([0, 3]).astype(label_dtype))
 
-  def testBfloat16(self):
-    for label_dtype in np.int32, np.int64:
-      np_features = np.array([[1., 1., 1., 1.], [1., 2., 3.,
-                                                 4.]]).astype(np.float32)
-      np_labels = np.array([0, 3]).astype(label_dtype)
-      np_loss, np_backprop = self._npXent(np_features, np_labels)
-
-      np_features_bf16 = math_ops.cast(np_features, dtypes.bfloat16)
-      np_loss_bf16 = math_ops.cast(np_loss, dtypes.bfloat16)
-      np_backprop_bf16 = math_ops.cast(np_backprop, dtypes.bfloat16)
-      with self.cached_session(use_gpu=False):
-        loss, backprop = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
-            np_features_bf16, np_labels)
-        tf_loss, tf_backprop = self.evaluate([loss, backprop])
-        self.assertAllCloseAccordingToType(np_loss_bf16, tf_loss)
-        self.assertAllCloseAccordingToType(np_backprop_bf16, tf_backprop)
-
   def testHalf(self):
     for label_dtype in np.int32, np.int64:
       self._testXent(
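The deleted testBfloat16 compared the kernel against the file's _npXent NumPy reference, with both results cast to bfloat16 so that assertAllCloseAccordingToType applies bfloat16-level tolerances. For reference, the math being checked is the standard sparse softmax cross-entropy; a self-contained sketch (not TensorFlow source) of what a reference like _npXent presumably computes:

    // Reference math for one row of logits x with true class y:
    //   loss     = -log(softmax(x)[y])
    //   backprop = softmax(x) - one_hot(y)
    #include <algorithm>
    #include <cmath>
    #include <vector>

    void SparseXentReference(const std::vector<float>& logits, int label,
                             float* loss, std::vector<float>* backprop) {
      // Numerically stable softmax: subtract the row max before exp().
      float max_logit = *std::max_element(logits.begin(), logits.end());
      backprop->resize(logits.size());
      float sum = 0.f;
      for (size_t i = 0; i < logits.size(); ++i) {
        (*backprop)[i] = std::exp(logits[i] - max_logit);
        sum += (*backprop)[i];
      }
      for (float& p : *backprop) p /= sum;    // now holds softmax(x)
      *loss = -std::log((*backprop)[label]);  // cross-entropy at label y
      (*backprop)[label] -= 1.f;              // softmax(x) - one_hot(y)
    }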