PyTorch
Loading...
Searching...
No Matches
Tensor.h
Go to the documentation of this file.
1#pragma once
2
5
6namespace at {
7class TORCH_API OptionalTensorRef {
8 public:
9 OptionalTensorRef() = default;
10
12 ref_.unsafeReleaseTensorImpl();
13 }
14
16 : ref_(Tensor::unsafe_borrow_t{}, src) {
18 }
19
21 : ref_(Tensor::unsafe_borrow_t{}, rhs.ref_) {}
22
24 std::swap(ref_, rhs.ref_);
25 return *this;
26 }
27
28 bool has_value() const {
29 return ref_.defined();
30 }
31
32 const Tensor& getTensorRef() const & {
33 return ref_;
34 }
35
36 const Tensor& operator*() const & {
37 return ref_;
38 }
39
40 const Tensor* operator->() const & {
41 return &ref_;
42 }
43
44 operator bool() const {
45 return ref_.defined();
46 }
47
48 private:
49 Tensor ref_;
50};
51
52template <typename T>
54 // Return the grad argument in case of a hook with void return type to have an
55 // std::function with Tensor return type
56 static_assert(std::is_same<decltype(hook(Tensor())), void>::value,
57 "Expected hook to return void");
58 return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
59 OptionalTensorRef grad(grad_base);
60 fn(*grad);
61 return Tensor();
62 });
63}
64
65template <typename T>
67 return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
68 OptionalTensorRef grad(grad_base);
69 Tensor ret = fn(*grad);
70 return TensorBase(std::move(ret));
71 });
72}
73
74} // namespace at
#define TORCH_INTERNAL_ASSERT_DEBUG_ONLY(...)
Definition: Exception.h:534
Definition: Tensor.h:7
~OptionalTensorRef()
Definition: Tensor.h:11
const Tensor & getTensorRef() const &
Definition: Tensor.h:32
bool has_value() const
Definition: Tensor.h:28
OptionalTensorRef(const OptionalTensorRef &rhs)
Definition: Tensor.h:20
OptionalTensorRef & operator=(OptionalTensorRef rhs)
Definition: Tensor.h:23
const Tensor * operator->() const &
Definition: Tensor.h:40
const Tensor & operator*() const &
Definition: Tensor.h:36
OptionalTensorRef()=default
OptionalTensorRef(const TensorBase &src)
Definition: Tensor.h:15
Definition: TensorBase.h:79
bool defined() const
Definition: TensorBase.h:198
Definition: TensorBody.h:90
hook_return_void_t< T > register_hook(T &&hook) const
Registers a backward hook.
std::enable_if_t< std::is_same< typename c10::invoke_result_t< T &, Tensor >, Tensor >::value, unsigned > hook_return_var_t
Definition: TensorBody.h:1391
std::enable_if_t< std::is_void< typename c10::invoke_result_t< T &, Tensor > >::value, unsigned > hook_return_void_t
Definition: TensorBody.h:1389
Definition: TensorBase.h:34