4#include <c10/macros/Macros.h>
36 : type_(type), index_(index) {
45 Device(
const std::string& device_string);
50 return this->type_ == other.type_ && this->index_ == other.index_;
56 return !(*
this == other);
81 return type_ == DeviceType::CUDA;
86 return type_ == DeviceType::MPS;
91 return type_ == DeviceType::HIP;
96 return type_ == DeviceType::VE;
101 return type_ == DeviceType::XPU;
106 return type_ == DeviceType::IPU;
111 return type_ == DeviceType::XLA;
116 return type_ == DeviceType::HPU;
121 return type_ == DeviceType::Lazy;
126 return type_ == DeviceType::Vulkan;
131 return type_ == DeviceType::Metal;
136 return type_ == DeviceType::ORT;
141 return type_ == DeviceType::Meta;
146 return type_ == DeviceType::CPU;
151 return type_ != DeviceType::IPU && type_ != DeviceType::XLA &&
152 type_ != DeviceType::Lazy;
167 index_ == -1 || index_ >= 0,
168 "Device index must be -1 or non-negative, got ",
171 !is_cpu() || index_ <= 0,
172 "CPU device index must be -1 or zero, got ",
183struct hash<
c10::Device> {
187 static_assert(
sizeof(
c10::DeviceType) == 1,
"DeviceType is not 8-bit");
199 uint32_t bits =
static_cast<uint32_t
>(
static_cast<uint8_t
>(d.
type()))
201 static_cast<uint32_t
>(
static_cast<uint8_t
>(d.
index()));
202 return std::hash<uint32_t>{}(bits);
#define TORCH_INTERNAL_ASSERT_DEBUG_ONLY(...)
Definition: Exception.h:534
std::ostream & operator<<(std::ostream &stream, const Device &device)
int8_t DeviceIndex
An index representing a specific device; e.g., the 1 in GPU 1.
Definition: Device.h:18
DeviceType
Definition: DeviceType.h:33
Represents a compute device on which a tensor is located.
Definition: Device.h:30
bool operator!=(const Device &other) const noexcept
Returns true if the type or index of this Device differs from that of other.
Definition: Device.h:55
bool is_vulkan() const noexcept
Return true if the device is of Vulkan type.
Definition: Device.h:125
void set_index(DeviceIndex index)
Sets the device index.
Definition: Device.h:60
bool is_cuda() const noexcept
Return true if the device is of CUDA type.
Definition: Device.h:80
bool has_index() const noexcept
Returns true if the device has a non-default index.
Definition: Device.h:75
bool supports_as_strided() const noexcept
Return true if the device supports arbitrary strides.
Definition: Device.h:150
DeviceIndex index() const noexcept
Returns the optional index.
Definition: Device.h:70
bool is_xla() const noexcept
Return true if the device is of XLA type.
Definition: Device.h:110
bool is_mps() const noexcept
Return true if the device is of MPS type.
Definition: Device.h:85
bool operator==(const Device &other) const noexcept
Returns true if the type and index of this Device matches that of other.
Definition: Device.h:49
bool is_ve() const noexcept
Return true if the device is of VE type.
Definition: Device.h:95
bool is_lazy() const noexcept
Return true if the device is of Lazy type.
Definition: Device.h:120
bool is_metal() const noexcept
Return true if the device is of Metal type.
Definition: Device.h:130
bool is_hpu() const noexcept
Return true if the device is of HPU type.
Definition: Device.h:115
bool is_meta() const noexcept
Return true if the device is of Meta type.
Definition: Device.h:140
bool is_cpu() const noexcept
Return true if the device is of CPU type.
Definition: Device.h:145
bool is_hip() const noexcept
Return true if the device is of HIP type.
Definition: Device.h:90
bool is_xpu() const noexcept
Return true if the device is of XPU type.
Definition: Device.h:100
bool is_ort() const noexcept
Return true if the device is of ORT type.
Definition: Device.h:135
bool is_ipu() const noexcept
Return true if the device is of IPU type.
Definition: Device.h:105
DeviceType type() const noexcept
Returns the type of device this is.
Definition: Device.h:65
Device(const std::string &device_string)
Constructs a Device from a string description, for convenience.
Device(DeviceType type, DeviceIndex index=-1)
Constructs a new Device from a DeviceType and an optional device index.
Definition: Device.h:35
std::string str() const
Same string as returned from operator<<.