3#include <c10/util/irange.h>
4#include <ATen/core/IListRef.h>
18 int64_t first_dims = first.
dim();
19 int64_t second_dims = second.
dim();
20 TORCH_CHECK(first_dims == second_dims,
"Tensors must have same number of dimensions: got ",
21 first_dims,
" and ", second_dims);
22 for (
const auto dim : c10::irange(first_dims)) {
23 if (dim == dimension) {
26 int64_t first_dim_size = first.
sizes()[dim];
27 int64_t second_dim_size = second.
sizes()[dim];
28 TORCH_CHECK(first_dim_size == second_dim_size,
"Sizes of tensors must match except in dimension ",
29 dimension,
". Expected size ",
static_cast<long long>(first_dim_size),
" but got size ",
static_cast<long long>(second_dim_size),
" for tensor number ",
index,
" in the list.");
35 for(
const Tensor&
t : tensors) {
37 "zero-dimensional tensor (at position ", i,
") cannot be concatenated");
43 TORCH_CHECK(self.
dim() != 0,
"split expects at least a 1-dimensional tensor");
44 TORCH_CHECK(split_size >= 0,
"split expects split_size be non-negative, but got split_size=", split_size);
45 int64_t dim_size = self.
size(dim);
47 "split_size can only be 0 if dimension size is 0, "
48 "but got dimension size of ", dim_size);
50 int64_t num_splits = 1;
51 if (split_size != 0) {
54 num_splits = std::max<int64_t>((dim_size + split_size - 1) / split_size, 1);
#define TORCH_CHECK(cond,...)
Definition: Exception.h:505
int64_t numel() const
Definition: TensorBase.h:305
int64_t dim() const
Definition: TensorBase.h:115
IntArrayRef sizes() const
Definition: TensorBase.h:231
Definition: TensorBody.h:90
int64_t size(at::Dimname dim) const
Definition: TensorBody.h:3350
bool cat_should_skip_tensor(const Tensor &t)
Definition: TensorShape.h:11
at::Tensor clone_preserve_strides(const at::Tensor &self)
void check_cat_no_zero_dim(const MaterializedITensorListRef &tensors)
Definition: TensorShape.h:33
void check_cat_shape_except_dim(const Tensor &first, const Tensor &second, int64_t dimension, int64_t index)
Definition: TensorShape.h:17
int64_t get_num_splits(const Tensor &self, int64_t split_size, int64_t dim)
Definition: TensorShape.h:42
Definition: TensorBase.h:34
at::Tensor index(const at::Tensor &self, const c10::List< c10::optional< at::Tensor > > &indices)
Definition: Functions.h:3622
at::Tensor t(const at::Tensor &self)
Definition: Functions.h:7681