TensorBase.h
#pragma once

#include <c10/core/Device.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/core/Storage.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/TensorOptions.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/Exception.h>
#include <c10/util/ExclusivelyOwnedTensorTraits.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/Optional.h>
#include <c10/util/intrusive_ptr.h>

#include <ATen/core/NamedTensor.h>
#include <ATen/core/QuantizerBase.h>
#include <c10/core/SymIntArrayRef.h>
#include <ATen/core/TensorAccessor.h>

namespace c10 {
class Scalar;
}

namespace torch { namespace autograd {

struct Node;

}} // namespace torch::autograd

namespace at {

class Tensor;
class TensorBase;

// Convert Tensor to TensorBase without any need to include Tensor.h
TORCH_API const TensorBase& get_tensor_base(const Tensor& t);

namespace impl {
inline bool variable_excluded_from_dispatch() {
#ifdef C10_MOBILE
  // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change.
  return true;
#else
  return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset);
#endif
}

}

// NOTE: [Tensor vs. TensorBase]
//
// Tensor, being the central data structure in PyTorch, gets used and
// its header included almost everywhere. Unfortunately this means
// every time an operator signature is updated or changed in
// native_functions.yaml, you (and every other PyTorch developer) need
// to recompile all of ATen and its dependencies.
//
// TensorBase aims to break up these header dependencies, and improve
// incremental build times for all PyTorch developers. TensorBase
// represents a reference-counted handle to TensorImpl, exactly the
// same as Tensor. However, TensorBase doesn't have code-generated
// methods in its API and thus no dependence on native_functions.yaml.
//
// Usage tips
// ----------
// - You can `#define TORCH_ASSERT_NO_OPERATORS` at the top of a .cpp
//   or .cu file to ensure it has no header dependencies on
//   native_functions.yaml (direct or indirect).
// - Tensor inherits from TensorBase, so functions taking
//   `const TensorBase &` are callable with Tensor as well.
// - TensorBase can be converted to Tensor with `Tensor(tensor_base)`,
//   but this requires a reference-count bump. OptionalTensorRef on
//   the other hand can materialize a `const Tensor &` without
//   touching the reference-count.
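//
// A minimal sketch of the pattern described above (the file and function below
// are hypothetical, not part of PyTorch): a translation unit that only needs
// metadata queries can take `const TensorBase &` and opt out of any
// native_functions.yaml dependency.
//
//   // metadata_helpers.cpp (hypothetical)
//   #define TORCH_ASSERT_NO_OPERATORS
//   #include <ATen/core/TensorBase.h>
//
//   int64_t rows_times_cols(const at::TensorBase& t) {
//     TORCH_CHECK(t.dim() == 2, "expected a 2-d tensor");
//     return t.size(0) * t.size(1);  // also callable with an at::Tensor argument
//   }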
class TORCH_API TensorBase {
 public:
  struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; };

 protected:
  // Create a Tensor with a +0 reference count. Special care must be
  // taken to avoid decrementing this reference count at destruction
  // time. Intended to support MaybeOwnedTraits<Tensor>.
  TensorBase(unsafe_borrow_t, const TensorBase& rhs)
      : impl_(c10::intrusive_ptr<at::TensorImpl, UndefinedTensorImpl>::reclaim(rhs.impl_.get())) {}
  friend MaybeOwnedTraits<TensorBase>;

 public:
  TensorBase() = default;
  // This constructor should not be used by end users and is an implementation
  // detail invoked by autogenerated code.
  explicit TensorBase(
      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
      : impl_(std::move(tensor_impl)) {
    if (impl_.get() == nullptr) {
      throw std::runtime_error("TensorImpl with nullptr is not supported");
    }
  }
  TensorBase(const TensorBase&) = default;
  TensorBase(TensorBase&&) = default;

 public:
  // Creates a new wrapper from TensorImpl. Intentionally a free method because
  // it should be used with care. Checks necessary invariants
  static TensorBase wrap_tensor_impl(
      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) {
    TensorBase r(std::move(tensor_impl));
    r.enforce_invariants();
    return r;
  }

  int64_t dim() const {
    return impl_->dim();
  }
  int64_t storage_offset() const {
    return impl_->storage_offset();
  }

  TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const {
    if (is_contiguous(memory_format)) {
      return *this;
    } else {
      return __dispatch_contiguous(memory_format);
    }
  }

  /// Should be used if *this can reasonably be expected to be contiguous and
  /// performance is important.
  c10::MaybeOwned<TensorBase> expect_contiguous(
      MemoryFormat memory_format=MemoryFormat::Contiguous) const &;

  // Use .contiguous() instead. Trying to borrow from a prvalue
  // will only lead to trouble and dangling references.
  c10::MaybeOwned<TensorBase> expect_contiguous(
      MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete;
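  // Illustrative use (an editorial sketch, not part of the upstream header):
  // borrowing avoids a reference-count bump when the tensor is already
  // contiguous.
  //
  //   void process(const at::TensorBase& t) {
  //     c10::MaybeOwned<at::TensorBase> c = t.expect_contiguous();
  //     const float* p = c->data_ptr<float>();
  //     // ... read numel() contiguous floats starting at p ...
  //   }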

  const TensorBase& fill_(const c10::Scalar& scalar) const;
  const TensorBase& zero_() const;

  TensorBase to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, c10::optional<at::MemoryFormat> memory_format=c10::nullopt) const;

  bool is_complex() const {
    return at::isComplexType(this->scalar_type());
  }

  bool is_floating_point() const {
    return at::isFloatingType(this->scalar_type());
  }

  bool is_signed() const {
    return at::isSignedType(this->scalar_type());
  }

  c10::SymInt sym_size(int64_t dim) const {
    return impl_->sym_size(dim);
  }

  c10::SymInt sym_stride(int64_t dim) const {
    const auto strides = this->sym_strides();
    const auto ndim = static_cast<int64_t>(strides.size());
    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
    return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
  }

  int64_t size(int64_t dim) const {
    return impl_->size(dim);
  }

  int64_t stride(int64_t dim) const {
    const auto strides = this->strides();
    const auto ndim = static_cast<int64_t>(strides.size());
    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
    return strides[c10::maybe_wrap_dim(dim, ndim, /*wrap_scalar=*/false)];
  }

  TensorImpl * unsafeGetTensorImpl() const {
    return impl_.get();
  }
  TensorImpl * unsafeReleaseTensorImpl() {
    return impl_.release();
  }
  const c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>& getIntrusivePtr() const {
    return impl_;
  }

  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> unsafeReleaseIntrusivePtr() {
    return std::move(impl_);
  }

  bool defined() const {
    return impl_;
  }

  void reset() {
    impl_.reset();
  }

  TensorBase& operator=(const TensorBase& x) & {
    impl_ = x.impl_;
    return *this;
  }
  TensorBase& operator=(TensorBase&& x) & noexcept {
    impl_ = std::move(x.impl_);
    return *this;
  }

  // Ban assignment to rvalues, since at::Tensor (weirdly) performs a deep copy here
  TensorBase& operator=(const TensorBase&) && = delete;
  TensorBase& operator=(TensorBase&&) && noexcept = delete;

  bool is_same(const TensorBase& other) const noexcept {
    return impl_ == other.impl_;
  }
  size_t use_count() const noexcept {
    return impl_.use_count();
  }
  size_t weak_use_count() const noexcept {
    return impl_.weak_use_count();
  }

  std::string toString() const;

  IntArrayRef sizes() const {
    return impl_->sizes();
  }
  c10::SymIntArrayRef sym_sizes() const {
    return impl_->sym_sizes();
  }
  c10::SymIntArrayRef sym_strides() const {
    return impl_->sym_strides();
  }
  IntArrayRef strides() const {
    return impl_->strides();
  }
  // See impl::get_opt_names in ATen/NamedTensor.h for docs.
  c10::optional<DimnameList> opt_names() const {
    return impl::get_opt_names(unsafeGetTensorImpl());
  }
  // See impl::get_names in ATen/NamedTensor.h for docs.
  DimnameList names() const {
    return impl::get_names(unsafeGetTensorImpl());
  }
  int64_t ndimension() const {
    return dim();
  }

  bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
    return impl_->is_contiguous(memory_format);
  }

  bool is_non_overlapping_and_dense() const {
    return impl_->is_non_overlapping_and_dense();
  }

  at::MemoryFormat suggest_memory_format(
      bool channels_last_strides_exact_match = false) const {
    // Setting channels_last_strides_exact_match to true forces the function to
    // also check the strides of 0- and 1-sized dimensions.
    if (layout() == at::kStrided) {
      if (impl_->is_strides_like_channels_last()) {
        if (!channels_last_strides_exact_match ||
            get_channels_last_strides_2d(sizes()) == strides()) {
          return at::MemoryFormat::ChannelsLast;
        }
      }
      else if (impl_->is_strides_like_channels_last_3d()) {
        if (!channels_last_strides_exact_match ||
            get_channels_last_strides_3d(sizes()) == strides()) {
          return at::MemoryFormat::ChannelsLast3d;
        }
      }
    }
    return at::MemoryFormat::Contiguous;
  }

  // Total bytes consumed by the "view" of elements of the array. Does not
  // include size of metadata. The number reported here does not necessarily
  // correspond to the true physical memory consumed by a tensor; instead,
  // it reports the memory the tensor would take *if* it were contiguous.
  // Defined to be numel() * itemsize().
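  // For example (an illustration, not from the original comment): a contiguous
  // 2x3 float32 tensor has numel() == 6 and itemsize() == 4, so nbytes() == 24.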
  size_t nbytes() const {
    TORCH_CHECK(layout() != at::kSparse,
                "nbytes is not defined for sparse tensors. If you want the size of the constituent " \
                "tensors, add the nbytes of the indices and values. If you want the size of the " \
                "equivalent dense tensor, multiply numel() by element_size()");
    return impl_->numel() * impl_->itemsize();
  }

  c10::SymInt sym_nbytes() const {
    TORCH_CHECK(layout() != at::kSparse,
                "nbytes is not defined for sparse tensors. If you want the size of the constituent " \
                "tensors, add the nbytes of the indices and values. If you want the size of the " \
                "equivalent dense tensor, multiply numel() by element_size()");
    return impl_->sym_numel() * impl_->itemsize();
  }

  int64_t numel() const {
    return impl_->numel();
  }

  c10::SymInt sym_numel() const {
    return impl_->sym_numel();
  }

  c10::SymInt sym_storage_offset() const {
    return impl_->sym_storage_offset();
  }

  // Length of one array element in bytes. This is the traditional
  // Numpy naming.
  size_t itemsize() const {
    return impl_->itemsize();
  }

  // Same as itemsize(). This is the PyTorch naming.
  int64_t element_size() const {
    return static_cast<int64_t>(impl_->itemsize());
  }

  DispatchKeySet key_set() const {
    return impl_->key_set();
  }
  ScalarType scalar_type() const {
    return typeMetaToScalarType(impl_->dtype());
  }
  bool has_storage() const {
    return defined() && impl_->has_storage();
  }
  const Storage& storage() const {
    return impl_->storage();
  }
  bool is_alias_of(const at::TensorBase& other) const {
    return impl_->storage().is_alias_of(other.storage());
  }

  inline bool _is_zerotensor() const {
    return impl_->_is_zerotensor();
  }

  inline void _set_zero(bool zero) const {
    impl_->_set_zero(zero);
  }

  inline bool is_conj() const {
    return impl_->is_conj();
  }

  // Sets the conjugate bit of a tensor.
  // NOTE: Conjugate bit is supposed to be a read-only field. Only change this if you are sure
  // that's what you want. Changing this might lead to incorrect behavior since conjugation is
  // a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized.
  inline void _set_conj(bool conjugate) const {
    impl_->_set_conj(conjugate);
  }

  inline bool is_neg() const {
    return impl_->is_neg();
  }

  // Sets the negative bit of a tensor.
  // NOTE: Negative bit is supposed to be a read-only field. Only change this if you are sure
  // that's what you want. Changing this might lead to incorrect behavior since we rely on this
  // bit to determine if a negation needs to be materialized.
  inline void _set_neg(bool negative) const {
    impl_->_set_neg(negative);
  }

  /// Returns a `Tensor`'s layout.
  Layout layout() const {
    return impl_->layout();
  }

  /// Returns a `Tensor`'s dtype (`TypeMeta`).
  caffe2::TypeMeta dtype() const {
    return impl_->dtype();
  }

  /// Returns a `Tensor`'s device.
  inline Device device() const {
    return impl_->device();
  }

  /// Returns a `Tensor`'s device index.
  int64_t get_device() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->get_device();
  }

  /// Returns if a `Tensor` has CPU backend.
  bool is_cpu() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_cpu();
  }

  /// Returns if a `Tensor` has CUDA backend.
  bool is_cuda() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_cuda();
  }

  /// Returns if a `Tensor` has IPU backend.
  bool is_ipu() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_ipu();
  }

  /// Returns if a `Tensor` has XPU backend.
  bool is_xpu() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_xpu();
  }

  /// Returns if a `Tensor` has XLA backend.
  bool is_xla() const {
    return impl_->is_xla();
  }

  /// Returns if a `Tensor` has HPU backend.
  bool is_hpu() const {
    return impl_->is_hpu();
  }

  /// Returns if a `Tensor` has Lazy backend.
  bool is_lazy() const {
    return impl_->is_lazy();
  }

  /// Returns if a `Tensor` has HIP backend.
  bool is_hip() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_hip();
  }

  /// Returns if a `Tensor` has VE backend.
  bool is_ve() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_ve();
  }

  /// Returns if a `Tensor` has sparse backend.
  bool is_sparse() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_sparse();
  }

  /// Returns if a `Tensor` has a sparse CSR backend.
  bool is_sparse_csr() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_sparse_csr();
  }

  /// Returns if a `Tensor` is an mkldnn tensor.
  bool is_mkldnn() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_mkldnn();
  }

  /// Returns if a `Tensor` is an mps tensor.
  bool is_mps() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_mps();
  }

  /// Returns if a `Tensor` is an ort tensor.
  bool is_ort() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_ort();
  }

  /// Returns if a `Tensor` is a vulkan tensor.
  bool is_vulkan() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_vulkan();
  }

  /// Returns if a `Tensor` is a metal tensor.
  bool is_metal() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_metal();
  }

  /// Returns if a `Tensor` has quantized backend.
  bool is_quantized() const {
    // NB: this is not a native function to avoid dispatching overhead.
    return impl_->is_quantized();
  }

  /// Returns if a `Tensor` is a meta tensor.
  bool is_meta() const {
    return impl_->is_meta();
  }

  /// Returns if a `Tensor` is an inference tensor.
  bool is_inference() const {
    return impl_->is_inference();
  }

  // Returns if a `Tensor` is a NestedTensor.
  bool is_nested() const {
    return impl_->is_nested();
  }

  /// If a tensor is a quantized tensor, returns its quantizer.
  QuantizerPtr quantizer() const;

  /// Returns if a `Tensor` has any dimension names.
  bool has_names() const {
    // If a user is using unnamed tensors, then we can short-circuit right here.
    // Otherwise, impl::has_names attempts to retrieve names.
    if (!impl_->has_named_tensor_meta()) {
      return false;
    }
    return impl::has_names(unsafeGetTensorImpl());
  }

  /// Returns a `Tensor`'s dimension names data structure.
  const NamedTensorMeta* get_named_tensor_meta() const {
    return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
  }

  NamedTensorMeta* get_named_tensor_meta() {
    return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
  }

  /// Returns the `TensorOptions` corresponding to this `Tensor`.
  TensorOptions options() const {
    return TensorOptions().dtype(dtype())
                          .device(device())
                          .layout(layout());
  }
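  // Typical use (sketch; at::empty is part of the generated ATen API, not this
  // header): allocate a new tensor with the same dtype, device and layout.
  //
  //   at::Tensor out = at::empty({2, 3}, t.options());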

  void* data_ptr() const {
    return this->unsafeGetTensorImpl()->data();
  }

  template <typename T>
  T * data_ptr() const;

  // Purposely not defined here to avoid inlining
  void print() const;

  // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and
  // dimension.
  template<typename T, size_t N>
  TensorAccessor<T,N> accessor() const& {
    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
    return TensorAccessor<T,N>(data_ptr<T>(),sizes().data(),strides().data());
  }
  template<typename T, size_t N>
  TensorAccessor<T,N> accessor() && = delete;
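  // Example (an illustrative sketch, assuming a 2-d float CPU tensor `t`):
  //
  //   auto a = t.accessor<float, 2>();
  //   for (int64_t i = 0; i < a.size(0); ++i) {
  //     for (int64_t j = 0; j < a.size(1); ++j) {
  //       a[i][j] = 0.0f;
  //     }
  //   }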

  // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and
  // dimension. You can optionally specify RestrictPtrTraits as a template parameter to
  // cast the data pointer to a __restrict__ pointer.
  // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor
  // as an argument.
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
  GenericPackedTensorAccessor<T,N,PtrTraits,index_t> generic_packed_accessor() const& {
    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
    return GenericPackedTensorAccessor<T,N,PtrTraits,index_t>(static_cast<typename PtrTraits<T>::PtrType>(data_ptr<T>()),sizes().data(),strides().data());
  }
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
  GenericPackedTensorAccessor<T,N> generic_packed_accessor() && = delete;

  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() const& {
    TORCH_CHECK(
        impl_->numel() <=
            static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
        "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64");
    return generic_packed_accessor<T,N,PtrTraits,int32_t>();
  }
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() && = delete;

  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() const& {
    return generic_packed_accessor<T,N,PtrTraits,int64_t>();
  }
  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
  PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() && = delete;
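  // Sketch of the CUDA-side pattern described above (the kernel and launch are
  // illustrative, not part of this header):
  //
  //   __global__ void fill_ones(at::PackedTensorAccessor32<float, 2> a) {
  //     const int i = blockIdx.x, j = threadIdx.x;
  //     if (i < a.size(0) && j < a.size(1)) a[i][j] = 1.0f;
  //   }
  //   // host side:
  //   // fill_ones<<<t.size(0), t.size(1)>>>(t.packed_accessor32<float, 2>());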

  // ~~~~~ Autograd API ~~~~~

  const TensorBase& set_requires_grad(bool requires_grad) const {
    impl_->set_requires_grad(requires_grad);
    return *this;
  }
  bool requires_grad() const {
    return impl_->requires_grad();
  }

  // The Forward AD API functions below are low level and are not to be used by end
  // users who should use the API provided in torch/csrc/autograd.h

  /// This function returns the forward gradient for this Tensor at the given level.
  const Tensor& _fw_grad(uint64_t level) const {
    return impl_->_fw_grad(level, *this);
  }

  /// This function can be used to set the value of the forward grad.
  void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const {
    impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op);
  }

  /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is
  /// intended to be used from functions that need to access the `Variable`'s
  /// equivalent `Tensor` (i.e. `Tensor` that shares the same storage and tensor
  /// metadata with the `Variable`).
  at::TensorBase tensor_data() const;

  /// NOTE: var.variable_data() in C++ has the same semantics as tensor.data
  /// in Python, which creates a new `Variable` that shares the same storage and
  /// tensor metadata with the original `Variable`, but with a completely new
  /// autograd history.
  at::TensorBase variable_data() const;

  // Gradient Node and Edges
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// Gets the gradient function of the `Variable`.
  const std::shared_ptr<torch::autograd::Node>& grad_fn() const;

  // Hooks
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  template <typename T>
  using hook_return_void_t = std::enable_if_t<std::is_void<typename c10::invoke_result_t<T&, TensorBase>>::value, unsigned>;
  template <typename T>
  using hook_return_var_t = std::enable_if_t<std::is_same<typename c10::invoke_result_t<T&, TensorBase>, TensorBase>::value, unsigned>;

  /// Registers a backward hook.
  template <typename T>
  hook_return_void_t<T> register_hook(T&& hook) const;
  template <typename T>
  hook_return_var_t<T> register_hook(T&& hook) const;

protected:
  unsigned _register_hook(std::function<TensorBase(const TensorBase&)> hook) const;

public:

  /// Remove hook at given position.
  void remove_hook(unsigned pos) const;

  // Variable methods
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  bool is_leaf() const;

  int64_t output_nr() const;

  void set_data(const TensorBase & new_data) const;

  TensorBase data() const;

  int64_t _version() const;

  void retain_grad() const;

  bool retains_grad() const;

  const TensorBase& requires_grad_(bool _requires_grad=true) const;

  // View Variables
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  /// Returns true if this `Variable` is a view of another `Variable`.
  bool is_view() const;

  /// Returns the `Variable` that this `Variable` is a view of.
  const TensorBase& _base() const;

  // Miscellaneous
  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

  const std::string& name() const;

protected:
  void enforce_invariants();
  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;

private:
  TensorBase __dispatch_contiguous(c10::MemoryFormat) const;
};

inline int64_t get_device(const TensorBase& self) {
  return self.get_device();
}

template <typename T>
auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_void_t<T> {
  // Return the grad argument in case of a hook with void return type to have an
  // std::function with Tensor return type
  static_assert(std::is_same<decltype(hook(TensorBase())), void>::value,
                "Expected hook to return void");
  return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad) {
    fn(grad);
    return TensorBase();
  });
}

template <typename T>
auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_var_t<T> {
  return _register_hook(std::forward<T>(hook));
}
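
// Usage sketch (illustrative only; torch::ones and torch::requires_grad come
// from the broader libtorch API, not this header). A hook observes the gradient
// flowing into the tensor during backward; a void-returning hook only observes it.
//
//   auto t = torch::ones({2, 2}, torch::requires_grad());
//   unsigned h = t.register_hook([](const at::TensorBase& grad) {
//     std::cout << grad.sizes() << std::endl;
//   });
//   // ... run backward() ...
//   t.remove_hook(h);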

namespace detail {
// Helper creator for the Tensor class which doesn't require the user to pass
// in an intrusive_ptr; instead it just converts the arguments passed into the
// requested intrusive_ptr type.
template <typename T, typename... Args>
TensorBase make_tensor_base(Args&&... args) {
  return TensorBase(c10::make_intrusive<T>(std::forward<Args>(args)...));
}

} // namespace detail

static inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) {
  return legacyExtractDispatchKey(t.key_set());
}

} // namespace at

namespace c10 {
template <>
struct MaybeOwnedTraits<at::TensorBase> {
  using owned_type = at::TensorBase;
  using borrow_type = at::TensorBase;

  static borrow_type createBorrow(const owned_type& from) {
    // NOTE: this can be implemented without the special
    // unsafe_borrow_t Tensor constructor as
    //
    // return borrow_type(c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::reclaim(from.unsafeGetTensorImpl()));
    //
    // but that hurts inlining due to the nullptr check in the
    // Tensor(c10::intrusive_ptr<...>) constructor. We already know
    // that from.impl_ isn't null because from is a valid Tensor, so
    // we needn't do the check again. (using __builtin_assume can
    // avoid this, but wouldn't be portable to MSVC.)
    return borrow_type(borrow_type::unsafe_borrow_t{}, from);
  }

  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
    lhs.unsafeReleaseTensorImpl();
    // See above note: this can be implemented with public API
    // similarly to createBorrow(), but that would hurt inlining.
    lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
  }

  static void destroyBorrow(borrow_type& toDestroy) {
    toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0.
  }

  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
    return borrow;
  }

  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
    return &borrow;
  }

  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
    return true;
  }
};

template <>
struct ExclusivelyOwnedTraits<at::TensorBase> : public c10::ExclusivelyOwnedTensorTraits<at::TensorBase> {};
} // namespace c10

namespace at {

inline c10::MaybeOwned<TensorBase> borrow_from_optional_tensor(
    const c10::optional<TensorBase>& opt) {
  return opt.has_value()
    ? c10::MaybeOwned<TensorBase>::borrowed(*opt)
    : c10::MaybeOwned<TensorBase>::owned(c10::in_place);
}
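
// Typical use (an illustrative sketch; the function name is hypothetical):
// kernels taking an optional tensor argument materialize a usable reference
// without copying when a value is present.
//
//   void my_op(const at::TensorBase& self, const c10::optional<at::TensorBase>& weight_opt) {
//     c10::MaybeOwned<at::TensorBase> weight = at::borrow_from_optional_tensor(weight_opt);
//     if (weight->defined()) {
//       // use *weight alongside self
//     }
//   }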

inline c10::MaybeOwned<TensorBase> TensorBase::expect_contiguous(MemoryFormat memory_format) const & {
  if (is_contiguous(memory_format)) {
    return c10::MaybeOwned<TensorBase>::borrowed(*this);
  } else {
    return c10::MaybeOwned<TensorBase>::owned(__dispatch_contiguous(memory_format));
  }
}

namespace symint {

template <typename T>
using enable_if_symint = std::enable_if_t<std::is_same<T, c10::SymInt>::value>;
template <typename T>
using enable_if_int = std::enable_if_t<std::is_same<T, int64_t>::value>;

template <typename T, typename = enable_if_symint<T>>
c10::SymIntArrayRef sizes(const TensorBase& t) { return t.sym_sizes(); }
template <typename T, typename = enable_if_int<T>>
IntArrayRef sizes(const TensorBase& t) { return t.sizes(); }

template <typename T, typename = enable_if_symint<T>>
c10::SymInt size(const TensorBase& t, int64_t dim) { return t.sym_size(dim); }
template <typename T, typename = enable_if_int<T>>
int64_t size(const TensorBase& t, int64_t dim) { return t.size(dim); }

template <typename T, typename = enable_if_symint<T>>
c10::SymIntArrayRef strides(const TensorBase& t) { return t.sym_strides(); }
template <typename T, typename = enable_if_int<T>>
IntArrayRef strides(const TensorBase& t) { return t.strides(); }

template <typename T, typename = enable_if_symint<T>>
c10::SymInt numel(const TensorBase& t) { return t.sym_numel(); }
template <typename T, typename = enable_if_int<T>>
int64_t numel(const TensorBase& t) { return t.numel(); }

} // namespace symint
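
// Sketch (illustrative only): code templated on an index type can use
// at::symint to pick between the SymInt and int64_t views of the same metadata.
//
//   template <typename T>
//   auto first_dim(const at::TensorBase& t) {
//     return at::symint::size<T>(t, 0);  // c10::SymInt or int64_t depending on T
//   }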

} // namespace at