PyTorch
Loading...
Searching...
No Matches
Device.h
Go to the documentation of this file.
#pragma once

#include <c10/core/DeviceType.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iosfwd>
#include <string>
11
12namespace c10 {
13
18using DeviceIndex = int8_t;
19
30struct C10_API Device final {
32
35 /* implicit */ Device(DeviceType type, DeviceIndex index = -1)
36 : type_(type), index_(index) {
37 validate();
38 }
39
45 /* implicit */ Device(const std::string& device_string);
46
49 bool operator==(const Device& other) const noexcept {
50 return this->type_ == other.type_ && this->index_ == other.index_;
51 }
52
55 bool operator!=(const Device& other) const noexcept {
56 return !(*this == other);
57 }
58
60 void set_index(DeviceIndex index) {
61 index_ = index;
62 }
63
65 DeviceType type() const noexcept {
66 return type_;
67 }
68
70 DeviceIndex index() const noexcept {
71 return index_;
72 }
73
75 bool has_index() const noexcept {
76 return index_ != -1;
77 }
78
80 bool is_cuda() const noexcept {
81 return type_ == DeviceType::CUDA;
82 }
83
85 bool is_mps() const noexcept {
86 return type_ == DeviceType::MPS;
87 }
88
90 bool is_hip() const noexcept {
91 return type_ == DeviceType::HIP;
92 }
93
95 bool is_ve() const noexcept {
96 return type_ == DeviceType::VE;
97 }
98
100 bool is_xpu() const noexcept {
101 return type_ == DeviceType::XPU;
102 }
103
105 bool is_ipu() const noexcept {
106 return type_ == DeviceType::IPU;
107 }
108
110 bool is_xla() const noexcept {
111 return type_ == DeviceType::XLA;
112 }
113
115 bool is_hpu() const noexcept {
116 return type_ == DeviceType::HPU;
117 }
118
120 bool is_lazy() const noexcept {
121 return type_ == DeviceType::Lazy;
122 }
123
125 bool is_vulkan() const noexcept {
126 return type_ == DeviceType::Vulkan;
127 }
128
130 bool is_metal() const noexcept {
131 return type_ == DeviceType::Metal;
132 }
133
135 bool is_ort() const noexcept {
136 return type_ == DeviceType::ORT;
137 }
138
140 bool is_meta() const noexcept {
141 return type_ == DeviceType::Meta;
142 }
143
145 bool is_cpu() const noexcept {
146 return type_ == DeviceType::CPU;
147 }
148
150 bool supports_as_strided() const noexcept {
151 return type_ != DeviceType::IPU && type_ != DeviceType::XLA &&
152 type_ != DeviceType::Lazy;
153 }
154
156 std::string str() const;
157
158 private:
159 DeviceType type_;
160 DeviceIndex index_ = -1;
161 void validate() {
162 // Removing these checks in release builds noticeably improves
163 // performance in micro-benchmarks.
164 // This is safe to do, because backends that use the DeviceIndex
165 // have a later check when we actually try to switch to that device.
167 index_ == -1 || index_ >= 0,
168 "Device index must be -1 or non-negative, got ",
169 (int)index_);
171 !is_cpu() || index_ <= 0,
172 "CPU device index must be -1 or zero, got ",
173 (int)index_);
174 }
175};
176
177C10_API std::ostream& operator<<(std::ostream& stream, const Device& device);
178
179} // namespace c10
180
181namespace std {
182template <>
183struct hash<c10::Device> {
184 size_t operator()(c10::Device d) const noexcept {
185 // Are you here because this static assert failed? Make sure you ensure
186 // that the bitmasking code below is updated accordingly!
187 static_assert(sizeof(c10::DeviceType) == 1, "DeviceType is not 8-bit");
188 static_assert(sizeof(c10::DeviceIndex) == 1, "DeviceIndex is not 8-bit");
189 // Note [Hazard when concatenating signed integers]
190 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
191 // We must first convert to a same-sized unsigned type, before promoting to
192 // the result type, to prevent sign extension when any of the values is -1.
193 // If sign extension occurs, you'll clobber all of the values in the MSB
194 // half of the resulting integer.
195 //
196 // Technically, by C/C++ integer promotion rules, we only need one of the
197 // uint32_t casts to the result type, but we put in both for explicitness's
198 // sake.
199 uint32_t bits = static_cast<uint32_t>(static_cast<uint8_t>(d.type()))
200 << 16 |
201 static_cast<uint32_t>(static_cast<uint8_t>(d.index()));
202 return std::hash<uint32_t>{}(bits);
203 }
204};
205} // namespace std
#define TORCH_INTERNAL_ASSERT_DEBUG_ONLY(...)
Definition: Exception.h:534
Definition: ivalue.h:27
std::ostream & operator<<(std::ostream &stream, const Device &device)
int8_t DeviceIndex
An index representing a specific device; e.g., the 1 in GPU 1.
Definition: Device.h:18
DeviceType
Definition: DeviceType.h:33
Definition: Device.h:181
Represents a compute device on which a tensor is located.
Definition: Device.h:30
bool operator!=(const Device &other) const noexcept
Returns true if the type or index of this Device differs from that of other.
Definition: Device.h:55
bool is_vulkan() const noexcept
Return true if the device is of Vulkan type.
Definition: Device.h:125
void set_index(DeviceIndex index)
Sets the device index.
Definition: Device.h:60
bool is_cuda() const noexcept
Return true if the device is of CUDA type.
Definition: Device.h:80
bool has_index() const noexcept
Returns true if the device has a non-default index.
Definition: Device.h:75
bool supports_as_strided() const noexcept
Return true if the device supports arbitrary strides.
Definition: Device.h:150
DeviceIndex index() const noexcept
Returns the optional index.
Definition: Device.h:70
bool is_xla() const noexcept
Return true if the device is of XLA type.
Definition: Device.h:110
bool is_mps() const noexcept
Return true if the device is of MPS type.
Definition: Device.h:85
bool operator==(const Device &other) const noexcept
Returns true if the type and index of this Device matches that of other.
Definition: Device.h:49
bool is_ve() const noexcept
Return true if the device is of VE type.
Definition: Device.h:95
bool is_lazy() const noexcept
Return true if the device is of Lazy type.
Definition: Device.h:120
bool is_metal() const noexcept
Return true if the device is of Metal type.
Definition: Device.h:130
bool is_hpu() const noexcept
Return true if the device is of HPU type.
Definition: Device.h:115
bool is_meta() const noexcept
Return true if the device is of META type.
Definition: Device.h:140
bool is_cpu() const noexcept
Return true if the device is of CPU type.
Definition: Device.h:145
bool is_hip() const noexcept
Return true if the device is of HIP type.
Definition: Device.h:90
bool is_xpu() const noexcept
Return true if the device is of XPU type.
Definition: Device.h:100
bool is_ort() const noexcept
Return true if the device is of ORT type.
Definition: Device.h:135
bool is_ipu() const noexcept
Return true if the device is of IPU type.
Definition: Device.h:105
DeviceType type() const noexcept
Returns the type of device this is.
Definition: Device.h:65
Device(const std::string &device_string)
Constructs a Device from a string description, for convenience.
Device(DeviceType type, DeviceIndex index=-1)
Constructs a new Device from a DeviceType and an optional device index.
Definition: Device.h:35
std::string str() const
Same string as returned from operator<<.