12#include <c10/macros/Macros.h>
13#include <c10/util/C++17.h>
14#include <c10/util/TypeSafeSignMath.h>
15#include <c10/util/complex.h>
18#if defined(__cplusplus) && (__cplusplus >= 201103L)
21#elif !defined(__OPENCL_VERSION__)
45#include <hip/hip_fp16.h>
48#if defined(CL_SYCL_LANGUAGE_VERSION)
50#elif defined(SYCL_LANGUAGE_VERSION)
51#include <sycl/sycl.hpp>
55#if defined(__clang__) && defined(__CUDA__) && defined(__CUDA_ARCH__)
56#define C10_DEVICE_HOST_FUNCTION __device__ __host__
58#define C10_DEVICE_HOST_FUNCTION
68#if defined(__OPENCL_VERSION__)
70#elif defined(__CUDA_ARCH__)
71 return __uint_as_float((
unsigned int)w);
72#elif defined(__INTEL_COMPILER)
73 return _castu32_f32(w);
84#if defined(__OPENCL_VERSION__)
86#elif defined(__CUDA_ARCH__)
87 return (uint32_t)__float_as_uint(f);
88#elif defined(__INTEL_COMPILER)
89 return _castf32_u32(f);
118 const uint32_t w = (uint32_t)h << 16;
127 const uint32_t sign = w & UINT32_C(0x80000000);
137 const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
148 unsigned long nonsign_bsr;
149 _BitScanReverse(&nonsign_bsr, (
unsigned long)nonsign);
150 uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
152 uint32_t renorm_shift = __builtin_clz(nonsign);
154 renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0;
161 const int32_t inf_nan_mask =
162 ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000);
170 const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
190 ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) |
216 const uint32_t w = (uint32_t)h << 16;
225 const uint32_t sign = w & UINT32_C(0x80000000);
235 const uint32_t two_w = w + w;
269 constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23;
271 constexpr uint32_t scale_bits = (uint32_t)15 << 23;
273 std::memcpy(&exp_scale_val, &scale_bits,
sizeof(exp_scale_val));
274 const float exp_scale = exp_scale_val;
275 const float normalized_value =
309 constexpr uint32_t magic_mask = UINT32_C(126) << 23;
310 constexpr float magic_bias = 0.5f;
311 const float denormalized_value =
322 constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27;
323 const uint32_t result = sign |
324 (two_w < denormalized_cutoff ?
fp32_to_bits(denormalized_value)
341 constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23;
342 constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23;
343 float scale_to_inf_val, scale_to_zero_val;
344 std::memcpy(&scale_to_inf_val, &scale_to_inf_bits,
sizeof(scale_to_inf_val));
346 &scale_to_zero_val, &scale_to_zero_bits,
sizeof(scale_to_zero_val));
347 const float scale_to_inf = scale_to_inf_val;
348 const float scale_to_zero = scale_to_zero_val;
350#if defined(_MSC_VER) && _MSC_VER == 1916
351 float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero;
353 float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
357 const uint32_t shl1_w = w + w;
358 const uint32_t sign = w & UINT32_C(0x80000000);
359 uint32_t bias = shl1_w & UINT32_C(0xFF000000);
360 if (bias < UINT32_C(0x71000000)) {
361 bias = UINT32_C(0x71000000);
366 const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
367 const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
368 const uint32_t nonsign = exp_bits + mantissa_bits;
369 return static_cast<uint16_t
>(
371 (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign));
386 C10_HOST_DEVICE
Half() =
default;
392 inline C10_HOST_DEVICE
Half(
float value);
393 inline C10_HOST_DEVICE
operator float()
const;
395#if defined(__CUDACC__) || defined(__HIPCC__)
396 inline C10_HOST_DEVICE
Half(
const __half& value);
397 inline C10_HOST_DEVICE
operator __half()
const;
399#ifdef SYCL_LANGUAGE_VERSION
400 inline C10_HOST_DEVICE
Half(
const sycl::half& value);
401 inline C10_HOST_DEVICE
operator sycl::half()
const;
407struct alignas(4) complex<
Half> {
416 : real_(real), imag_(imag) {}
417 C10_HOST_DEVICE
inline complex(
const c10::complex<float>& value)
418 : real_(value.real()), imag_(value.imag()) {}
421 inline C10_HOST_DEVICE
operator c10::complex<float>()
const {
422 return {real_, imag_};
432 C10_HOST_DEVICE complex<Half>&
operator+=(
const complex<Half>& other) {
433 real_ =
static_cast<float>(real_) +
static_cast<float>(other.real_);
434 imag_ =
static_cast<float>(imag_) +
static_cast<float>(other.imag_);
438 C10_HOST_DEVICE complex<Half>&
operator-=(
const complex<Half>& other) {
439 real_ =
static_cast<float>(real_) -
static_cast<float>(other.real_);
440 imag_ =
static_cast<float>(imag_) -
static_cast<float>(other.imag_);
444 C10_HOST_DEVICE complex<Half>&
operator*=(
const complex<Half>& other) {
445 auto a =
static_cast<float>(real_);
446 auto b =
static_cast<float>(imag_);
447 auto c =
static_cast<float>(other.real());
448 auto d =
static_cast<float>(other.imag());
449 real_ = a * c - b * d;
450 imag_ = a * d + b * c;
461#pragma warning(disable : 4146)
462#pragma warning(disable : 4804)
463#pragma warning(disable : 4018)
470#pragma GCC diagnostic push
471#pragma GCC diagnostic ignored "-Wunknown-warning-option"
472#pragma GCC diagnostic ignored "-Wimplicit-int-float-conversion"
479template <
typename To,
typename From>
480typename std::enable_if<std::is_same<From, bool>::value,
bool>::type
overflows(
486template <
typename To,
typename From>
487typename std::enable_if<
488 std::is_integral<From>::value && !std::is_same<From, bool>::value,
491 using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
492 if (!limit::is_signed && std::numeric_limits<From>::is_signed) {
496 return greater_than_max<To>(f) ||
497 (c10::is_negative(f) && -
static_cast<uint64_t
>(f) > limit::max());
499 return c10::less_than_lowest<To>(f) || greater_than_max<To>(f);
503template <
typename To,
typename From>
504typename std::enable_if<std::is_floating_point<From>::value,
bool>::type
506 using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
507 if (limit::has_infinity && std::isinf(
static_cast<double>(f))) {
510 if (!limit::has_quiet_NaN && (f != f)) {
513 return f < limit::lowest() || f > limit::max();
517#pragma GCC diagnostic pop
524template <
typename To,
typename From>
525typename std::enable_if<is_complex<From>::value,
bool>::type
overflows(From f) {
528 if (!is_complex<To>::value && f.imag() != 0) {
536 typename scalar_value_type<To>::type,
537 typename From::value_type>(f.real()) ||
539 typename scalar_value_type<To>::type,
540 typename From::value_type>(f.imag());
547#include <c10/util/Half-inl.h>
#define C10_DEVICE_HOST_FUNCTION
Defines the Half type (half-precision floating-point) including conversions to standard C types and b...
Definition: Half.h:58
float fp16_ieee_to_fp32_value(uint16_t h)
Definition: Half.h:204
uint32_t fp16_ieee_to_fp32_bits(uint16_t h)
Definition: Half.h:106
uint16_t fp16_ieee_from_fp32_value(float f)
Definition: Half.h:338
uint32_t fp32_to_bits(float f)
Definition: Half.h:83
float fp32_from_bits(uint32_t w)
Definition: Half.h:67
std::ostream & operator<<(std::ostream &stream, const Device &device)
std::enable_if< std::is_same< From, bool >::value, bool >::type overflows(From)
Definition: Half.h:480
unsigned short x
Definition: Half.h:377
constexpr Half(unsigned short bits, from_bits_t)
Definition: Half.h:391
static constexpr from_bits_t from_bits()
Definition: Half.h:380
complex< Half > & operator-=(const complex< Half > &other)
Definition: Half.h:438
complex(const c10::complex< float > &value)
Definition: Half.h:417
constexpr Half imag() const
Definition: Half.h:428
complex(const Half &real, const Half &imag)
Definition: Half.h:415
constexpr Half real() const
Definition: Half.h:425
Half real_
Definition: Half.h:408
complex< Half > & operator*=(const complex< Half > &other)
Definition: Half.h:444
complex< Half > & operator+=(const complex< Half > &other)
Definition: Half.h:432
Half imag_
Definition: Half.h:409