4#include <c10/core/Layout.h>
5#include <c10/core/MemoryFormat.h>
6#include <c10/core/ScalarType.h>
7#include <c10/core/ScalarTypeToTypeMeta.h>
8#include <c10/core/Storage.h>
9#include <c10/core/TensorImpl.h>
10#include <c10/core/TensorOptions.h>
11#include <c10/core/UndefinedTensorImpl.h>
12#include <c10/core/WrapDimMinimal.h>
14#include <c10/util/ExclusivelyOwnedTensorTraits.h>
15#include <c10/util/MaybeOwned.h>
17#include <c10/util/intrusive_ptr.h>
19#include <ATen/core/NamedTensor.h>
20#include <ATen/core/QuantizerBase.h>
21#include <c10/core/SymIntArrayRef.h>
22#include <ATen/core/TensorAccessor.h>
28namespace torch {
namespace autograd {
48 return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset);
88 : impl_(
c10::intrusive_ptr<
at::TensorImpl, UndefinedTensorImpl>::reclaim(rhs.impl_.get())) {}
89 friend MaybeOwnedTraits<TensorBase>;
96 c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
97 : impl_(
std::move(tensor_impl)) {
98 if (impl_.get() ==
nullptr) {
99 throw std::runtime_error(
"TensorImpl with nullptr is not supported");
109 c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) {
119 return impl_->storage_offset();
123 if (is_contiguous(memory_format)) {
126 return __dispatch_contiguous(memory_format);
136 c10::MaybeOwned<TensorBase> expect_contiguous(
137 MemoryFormat memory_format=MemoryFormat::Contiguous)
const &;
142 MemoryFormat memory_format=MemoryFormat::Contiguous) && =
delete;
150 return at::isComplexType(this->scalar_type());
154 return at::isFloatingType(this->scalar_type());
158 return at::isSignedType(this->scalar_type());
162 return impl_->sym_size(dim);
166 const auto sizes = this->sym_strides();
167 const auto ndim =
static_cast<int64_t
>(sizes.size());
169 return sizes[c10::maybe_wrap_dim(dim, ndim,
false)];
173 int64_t
size(int64_t dim)
const {
174 return impl_->size(dim);
178 const auto strides = this->strides();
179 const auto ndim =
static_cast<int64_t
>(strides.size());
181 return strides[c10::maybe_wrap_dim(dim, ndim,
false)];
188 return impl_.release();
195 return std::move(impl_);
211 impl_ = std::move(x.impl_);
220 return impl_ == other.impl_;
223 return impl_.use_count();
226 return impl_.weak_use_count();
232 return impl_->sizes();
235 return impl_->sym_sizes();
238 return impl_->sym_strides();
241 return impl_->strides();
245 return impl::get_opt_names(unsafeGetTensorImpl());
249 return impl::get_names(unsafeGetTensorImpl());
255 bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous)
const {
256 return impl_->is_contiguous(memory_format);
260 return impl_->is_non_overlapping_and_dense();
264 bool channels_last_strides_exact_match =
false)
const {
267 if (layout() == at::kStrided) {
268 if (impl_->is_strides_like_channels_last()) {
269 if (!channels_last_strides_exact_match ||
270 get_channels_last_strides_2d(sizes()) == strides()) {
271 return at::MemoryFormat::ChannelsLast;
274 else if (impl_->is_strides_like_channels_last_3d()) {
275 if (!channels_last_strides_exact_match ||
276 get_channels_last_strides_3d(sizes()) == strides()) {
277 return at::MemoryFormat::ChannelsLast3d;
281 return at::MemoryFormat::Contiguous;
291 "nbytes is not defined for sparse tensors. If you want the size of the constituent " \
292 "tensors, add the nbytes of the indices and values. If you want the size of the " \
293 "equivalent dense tensor, multiply numel() by element_size()");
294 return impl_->numel() * impl_->itemsize();
299 "nbytes is not defined for sparse tensors. If you want the size of the constituent " \
300 "tensors, add the nbytes of the indices and values. If you want the size of the " \
301 "equivalent dense tensor, multiply numel() by element_size()");
302 return impl_->sym_numel() * impl_->itemsize();
306 return impl_->numel();
310 return impl_->sym_numel();
314 return impl_->sym_storage_offset();
320 return impl_->itemsize();
325 return static_cast<int64_t
>(impl_->itemsize());
329 return impl_->key_set();
332 return typeMetaToScalarType(impl_->dtype());
335 return defined() && impl_->has_storage();
338 return impl_->storage();
341 return impl_->storage().is_alias_of(other.
storage());
345 return impl_->_is_zerotensor();
349 impl_->_set_zero(
zero);
353 return impl_->is_conj();
361 impl_->_set_conj(conjugate);
365 return impl_->is_neg();
378 return impl_->layout();
383 return impl_->dtype();
388 return impl_->device();
394 return impl_->get_device();
400 return impl_->is_cpu();
406 return impl_->is_cuda();
412 return impl_->is_ipu();
418 return impl_->is_xpu();
423 return impl_->is_xla();
428 return impl_->is_hpu();
433 return impl_->is_lazy();
439 return impl_->is_hip();
445 return impl_->is_ve();
451 return impl_->is_sparse();
457 return impl_->is_sparse_csr();
463 return impl_->is_mkldnn();
469 return impl_->is_mps();
475 return impl_->is_ort();
481 return impl_->is_vulkan();
487 return impl_->is_metal();
493 return impl_->is_quantized();
499 return impl_->is_meta();
504 return impl_->is_inference();
509 return impl_->is_nested();
520 if (!impl_->has_named_tensor_meta()) {
523 return impl::has_names(unsafeGetTensorImpl());
528 return static_cast<NamedTensorMeta*
>(impl_->named_tensor_meta());
532 return static_cast<NamedTensorMeta*
>(impl_->named_tensor_meta());
538 return TensorOptions().dtype(dtype())
544 return this->unsafeGetTensorImpl()->data();
547 template <
typename T>
555 template<
typename T,
size_t N>
557 static_assert(N > 0,
"accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
558 TORCH_CHECK(dim() == N,
"TensorAccessor expected ", N,
" dims but tensor has ", dim());
559 return TensorAccessor<T,N>(data_ptr<T>(),sizes().data(),strides().data());
561 template<
typename T,
size_t N>
569 template<
typename T,
size_t N,
template <
typename U>
class PtrTraits = DefaultPtrTraits,
typename index_t = int64_t>
571 static_assert(N > 0,
"accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
572 TORCH_CHECK(dim() == N,
"TensorAccessor expected ", N,
" dims but tensor has ", dim());
573 return GenericPackedTensorAccessor<T,N,PtrTraits,index_t>(
static_cast<typename PtrTraits<T>::PtrType
>(data_ptr<T>()),sizes().data(),strides().data());
575 template<
typename T,
size_t N,
template <
typename U>
class PtrTraits = DefaultPtrTraits,
typename index_t = int64_t>
578 template<
typename T,
size_t N,
template <
typename U>
class PtrTraits = DefaultPtrTraits>
582 static_cast<int64_t
>(std::numeric_limits<int32_t>::max()),
583 "numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64");
584 return generic_packed_accessor<T,N,PtrTraits,int32_t>();
586 template<
typename T,
size_t N,
template <
typename U>
class PtrTraits = DefaultPtrTraits>
589 template<
typename T,
size_t N,
template <
typename U>
class PtrTraits = DefaultPtrTraits>
591 return generic_packed_accessor<T,N,PtrTraits,int64_t>();
593 template<
typename T,
size_t N,
template <
typename U>
class PtrTraits = DefaultPtrTraits>
696 return impl_->requires_grad();
704 return impl_->
_fw_grad(level, *
this);
712 impl_->_set_fw_grad(new_grad, *
this, level, is_inplace_op);
748 const std::shared_ptr<torch::autograd::Node>&
grad_fn()
const;
753 template <
typename T>
754 using hook_return_void_t = std::enable_if_t<std::is_void<typename c10::invoke_result_t<T&, TensorBase>>::value,
unsigned>;
755 template <
typename T>
788 template <
typename T>
790 template <
typename T>
833 const std::string&
name()
const;
837 c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>
impl_;
840 TensorBase __dispatch_contiguous(c10::MemoryFormat)
const;
851 static_assert(std::is_same<
decltype(hook(
TensorBase())),
void>::value,
852 "Expected hook to return void");
853 return _register_hook([fn=std::forward<T>(hook)](
const TensorBase& grad) {
861 return _register_hook(std::forward<T>(hook));
868template <
typename T,
typename... Args>
870 return TensorBase(c10::make_intrusive<T>(std::forward<Args>(args)...));
883struct MaybeOwnedTraits<
at::TensorBase> {
926struct ExclusivelyOwnedTraits<
at::TensorBase> :
public c10::ExclusivelyOwnedTensorTraits<at::TensorBase> {};
934 ? c10::MaybeOwned<TensorBase>::borrowed(*opt)
935 : c10::MaybeOwned<TensorBase>::owned(c10::in_place);
939 if (is_contiguous(memory_format)) {
940 return c10::MaybeOwned<TensorBase>::borrowed(*
this);
942 return c10::MaybeOwned<TensorBase>::owned(__dispatch_contiguous(memory_format));
953template <
typename T,
typename = enable_if_sym
int<T>>
955template <
typename T,
typename = enable_if_
int<T>>
958template <
typename T,
typename = enable_if_sym
int<T>>
960template <
typename T,
typename = enable_if_
int<T>>
963template <
typename T,
typename = enable_if_sym
int<T>>
965template <
typename T,
typename = enable_if_
int<T>>
968template <
typename T,
typename = enable_if_sym
int<T>>
970template <
typename T,
typename = enable_if_
int<T>>
#define TORCH_CHECK(cond,...)
Definition: Exception.h:505
Definition: TensorBase.h:79
IntArrayRef strides() const
Definition: TensorBase.h:240
void retain_grad() const
Enables this Tensor to have its :attr:`grad` populated during :func:`backward`.
bool is_sparse() const
Returns if a Tensor has sparse backend.
Definition: TensorBase.h:449
bool is_inference() const
Returns if a Tensor is an inference tensor.
Definition: TensorBase.h:503
const std::string & name() const
bool is_signed() const
Definition: TensorBase.h:157
void _set_conj(bool conjugate) const
Definition: TensorBase.h:360
bool defined() const
Definition: TensorBase.h:198
c10::SymIntArrayRef sym_sizes() const
Definition: TensorBase.h:234
int64_t output_nr() const
c10::SymInt sym_size(int64_t dim) const
Definition: TensorBase.h:161
int64_t element_size() const
Definition: TensorBase.h:324
bool requires_grad() const
Definition: TensorBase.h:695
c10::optional< DimnameList > opt_names() const
Definition: TensorBase.h:244
bool is_cuda() const
Returns if a Tensor has CUDA backend.
Definition: TensorBase.h:404
c10::SymInt sym_stride(int64_t dim) const
Definition: TensorBase.h:165
int64_t storage_offset() const
Definition: TensorBase.h:118
const Tensor & _fw_grad(uint64_t level) const
This function returns the forward gradient for this Tensor at the given level.
Definition: TensorBase.h:703
TensorBase & operator=(TensorBase &&) &&noexcept=delete
bool is_metal() const
Returns if a Tensor is a metal tensor.
Definition: TensorBase.h:485
PackedTensorAccessor64< T, N, PtrTraits > packed_accessor64() &&=delete
at::TensorBase tensor_data() const
NOTE: This is similar to the legacy .data() function on Variable, and is intended to be used from fun...
unsigned _register_hook(std::function< TensorBase(const TensorBase &)> hook) const
bool has_storage() const
Definition: TensorBase.h:334
bool is_non_overlapping_and_dense() const
Definition: TensorBase.h:259
bool is_cpu() const
Returns if a Tensor has CPU backend.
Definition: TensorBase.h:398
QuantizerPtr quantizer() const
If a tensor is a quantized tensor, returns its quantizer. TODO: it's not in native_functions....
bool has_names() const
Returns if a Tensor has any dimension names.
Definition: TensorBase.h:517
void _set_fw_grad(const TensorBase &new_grad, uint64_t level, bool is_inplace_op) const
This function can be used to set the value of the forward grad.
Definition: TensorBase.h:711
bool is_floating_point() const
Definition: TensorBase.h:153
TensorBase & operator=(TensorBase &&x) &noexcept
Definition: TensorBase.h:210
const std::shared_ptr< torch::autograd::Node > & grad_fn() const
Gets the gradient function of the Variable.
TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const
Definition: TensorBase.h:122
TensorBase & operator=(const TensorBase &) &&=delete
c10::SymInt sym_numel() const
Definition: TensorBase.h:309
std::enable_if_t< std::is_void< typename c10::invoke_result_t< T &, TensorBase > >::value, unsigned > hook_return_void_t
Definition: TensorBase.h:754
bool is_nested() const
Definition: TensorBase.h:508
std::string toString() const
c10::intrusive_ptr< TensorImpl, UndefinedTensorImpl > impl_
Definition: TensorBase.h:837
TensorImpl * unsafeReleaseTensorImpl()
Definition: TensorBase.h:187
c10::intrusive_ptr< TensorImpl, UndefinedTensorImpl > unsafeReleaseIntrusivePtr()
Definition: TensorBase.h:194
TensorBase(const TensorBase &)=default
bool is_mkldnn() const
Returns if a Tensor is an mkldnn tensor.
Definition: TensorBase.h:461
bool is_quantized() const
Returns if a Tensor has quantized backend.
Definition: TensorBase.h:491
const Storage & storage() const
Definition: TensorBase.h:337
bool is_ipu() const
Returns if a Tensor has IPU backend.
Definition: TensorBase.h:410
std::enable_if_t< std::is_same< typename c10::invoke_result_t< T &, TensorBase >, TensorBase >::value, unsigned > hook_return_var_t
Definition: TensorBase.h:756
void remove_hook(unsigned pos) const
Remove hook at given position.
TensorBase & operator=(const TensorBase &x) &
Definition: TensorBase.h:206
int64_t numel() const
Definition: TensorBase.h:305
bool _is_zerotensor() const
Definition: TensorBase.h:344
int64_t dim() const
Definition: TensorBase.h:115
bool is_sparse_csr() const
Returns if a Tensor has a sparse CSR backend.
Definition: TensorBase.h:455
bool is_xpu() const
Returns if a Tensor has XPU backend.
Definition: TensorBase.h:416
TensorBase(unsafe_borrow_t, const TensorBase &rhs)
Definition: TensorBase.h:87
PackedTensorAccessor32< T, N, PtrTraits > packed_accessor32() const &
Definition: TensorBase.h:579
bool is_alias_of(const at::TensorBase &other) const
Definition: TensorBase.h:340
bool is_complex() const
Definition: TensorBase.h:149
bool is_vulkan() const
Returns if a Tensor is a vulkan tensor.
Definition: TensorBase.h:479
hook_return_var_t< T > register_hook(T &&hook) const
void set_data(const TensorBase &new_data) const
bool is_conj() const
Definition: TensorBase.h:352
PackedTensorAccessor64< T, N, PtrTraits > packed_accessor64() const &
Definition: TensorBase.h:590
TensorAccessor< T, N > accessor() &&=delete
TensorOptions options() const
Returns the TensorOptions corresponding to this Tensor.
Definition: TensorBase.h:537
DimnameList names() const
Definition: TensorBase.h:248
GenericPackedTensorAccessor< T, N, PtrTraits, index_t > generic_packed_accessor() const &
Definition: TensorBase.h:570
int64_t ndimension() const
Definition: TensorBase.h:251
bool is_view() const
Returns true if this Variable is a view of another Variable.
size_t weak_use_count() const noexcept
Definition: TensorBase.h:225
void * data_ptr() const
Definition: TensorBase.h:543
c10::SymInt sym_storage_offset() const
Definition: TensorBase.h:313
const TensorBase & requires_grad_(bool _requires_grad=true) const
DispatchKeySet key_set() const
Definition: TensorBase.h:328
void _set_neg(bool negative) const
Definition: TensorBase.h:372
bool is_ort() const
Returns if a Tensor is an ort tensor.
Definition: TensorBase.h:473
GenericPackedTensorAccessor< T, N > generic_packed_accessor() &&=delete
ScalarType scalar_type() const
Definition: TensorBase.h:331
at::TensorBase variable_data() const
NOTE: var.variable_data() in C++ has the same semantics as tensor.data in Python, which create a new ...
const TensorBase & zero_() const
at::MemoryFormat suggest_memory_format(bool channels_last_strides_exact_match=false) const
Definition: TensorBase.h:263
TensorBase(TensorBase &&)=default
TensorBase to(at::TensorOptions options={}, bool non_blocking=false, bool copy=false, c10::optional< at::MemoryFormat > memory_format=c10::nullopt) const
const TensorBase & fill_(const c10::Scalar &scalar) const
PackedTensorAccessor32< T, N, PtrTraits > packed_accessor32() &&=delete
const NamedTensorMeta * get_named_tensor_meta() const
Returns a Tensor's dimension names data structure.
Definition: TensorBase.h:527
Layout layout() const
Returns a Tensor's layout.
Definition: TensorBase.h:377
bool is_xla() const
Returns if a Tensor has XLA backend.
Definition: TensorBase.h:422
c10::SymIntArrayRef sym_strides() const
Definition: TensorBase.h:237
TensorAccessor< T, N > accessor() const &
Definition: TensorBase.h:556
bool is_meta() const
Returns if a Tensor is a meta tensor.
Definition: TensorBase.h:498
c10::MaybeOwned< TensorBase > expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const &
Should be used if *this can reasonably be expected to be contiguous and performance is important.
Definition: TensorBase.h:938
bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const
Definition: TensorBase.h:255
hook_return_void_t< T > register_hook(T &&hook) const
Registers a backward hook.
bool is_lazy() const
Returns if a Tensor has Lazy backend.
Definition: TensorBase.h:432
void reset()
Definition: TensorBase.h:202
c10::MaybeOwned< TensorBase > expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) &&=delete
bool is_neg() const
Definition: TensorBase.h:364
bool is_hpu() const
Returns if a Tensor has HPU backend.
Definition: TensorBase.h:427
size_t itemsize() const
Definition: TensorBase.h:319
int64_t size(int64_t dim) const
Definition: TensorBase.h:173
TensorBase(c10::intrusive_ptr< TensorImpl, UndefinedTensorImpl > tensor_impl)
Definition: TensorBase.h:95
bool is_ve() const
Returns if a Tensor has VE backend.
Definition: TensorBase.h:443
Device device() const
Returns a Tensor's device.
Definition: TensorBase.h:387
const TensorBase & _base() const
Returns the Variable that this Variable is a view of.
size_t use_count() const noexcept
Definition: TensorBase.h:222
NamedTensorMeta * get_named_tensor_meta()
Definition: TensorBase.h:531
bool is_mps() const
Returns if a Tensor is an mps tensor.
Definition: TensorBase.h:467
const c10::intrusive_ptr< TensorImpl, UndefinedTensorImpl > & getIntrusivePtr() const
Definition: TensorBase.h:190
const TensorBase & set_requires_grad(bool requires_grad) const
Definition: TensorBase.h:691
bool is_leaf() const
All Tensors whose requires_grad() is false are leaf Tensors by convention.
caffe2::TypeMeta dtype() const
Returns a Tensor's dtype (TypeMeta).
Definition: TensorBase.h:382
bool is_same(const TensorBase &other) const noexcept
Definition: TensorBase.h:219
bool retains_grad() const
Is true if this Tensor is non-leaf and its :attr:grad is enabled to be populated during :func:backwar...
int64_t get_device() const
Returns a Tensor's device index.
Definition: TensorBase.h:392
void enforce_invariants()
static TensorBase wrap_tensor_impl(c10::intrusive_ptr< TensorImpl, UndefinedTensorImpl > tensor_impl)
Definition: TensorBase.h:108
c10::SymInt sym_nbytes() const
Definition: TensorBase.h:297
void _set_zero(bool zero) const
Definition: TensorBase.h:348
int64_t stride(int64_t dim) const
Definition: TensorBase.h:177
IntArrayRef sizes() const
Definition: TensorBase.h:231
bool is_hip() const
Returns if a Tensor has HIP backend.
Definition: TensorBase.h:437
TensorImpl * unsafeGetTensorImpl() const
Definition: TensorBase.h:184
size_t nbytes() const
Definition: TensorBase.h:289
Definition: TensorBody.h:90
c10::SymInt sym_size(int64_t dim) const
Definition: TensorBase.h:161
const Tensor & _fw_grad(uint64_t level) const
This function returns the forward gradient for this Tensor at the given level.
Definition: TensorBody.h:500
int64_t size(at::Dimname dim) const
Definition: TensorBody.h:3350
Definition: Optional.h:549
constexpr bool has_value() const noexcept
Definition: Optional.h:735
TensorBase make_tensor_base(Args &&... args)
Definition: TensorBase.h:869
bool variable_excluded_from_dispatch()
Definition: TensorBase.h:43
std::enable_if_t< std::is_same< T, c10::SymInt >::value > enable_if_symint
Definition: TensorBase.h:949
c10::SymInt size(const TensorBase &t, int64_t dim)
Definition: TensorBase.h:959
c10::SymIntArrayRef sizes(const TensorBase &t)
Definition: TensorBase.h:954
c10::SymIntArrayRef strides(const TensorBase &t)
Definition: TensorBase.h:964
c10::SymInt numel(const TensorBase &t)
Definition: TensorBase.h:969
std::enable_if_t< std::is_same< T, int64_t >::value > enable_if_int
Definition: TensorBase.h:951
Definition: TensorBase.h:34
static DispatchKey legacyExtractDispatchKey(const TensorBase &t)
Definition: TensorBase.h:875
at::Tensor copy(const at::Tensor &self, const at::Tensor &src, bool non_blocking=false)
Definition: Functions.h:1872
at::Tensor zero(const at::Tensor &self)
Definition: Functions.h:23353
c10::MaybeOwned< TensorBase > borrow_from_optional_tensor(const c10::optional< TensorBase > &opt)
Definition: TensorBase.h:931
at::Tensor negative(const at::Tensor &self)
Definition: Functions.h:6469
const TensorBase & get_tensor_base(const Tensor &t)
int64_t get_device(const TensorBase &self)
Definition: TensorBase.h:843
at::Tensor t(const at::Tensor &self)
Definition: Functions.h:7681
constexpr nullopt_t nullopt
Definition: Optional.h:163
Definition: TensorBase.h:81
unsafe_borrow_t()=default
static void assignBorrow(borrow_type &lhs, const borrow_type &rhs)
Definition: TensorBase.h:901
static const owned_type * pointerFromBorrow(const borrow_type &borrow)
Definition: TensorBase.h:916
static borrow_type createBorrow(const owned_type &from)
Definition: TensorBase.h:887
static const owned_type & referenceFromBorrow(const borrow_type &borrow)
Definition: TensorBase.h:912
static bool debugBorrowIsValid(const borrow_type &)
Definition: TensorBase.h:920
static void destroyBorrow(borrow_type &toDestroy)
Definition: TensorBase.h:908