Update on "Support different NSE in batches of CSR and CSC tensors"
This PR enables batched CSR/CSC tensors whose batches may have different NSE (number of specified elements) counts.

For instance, with the current master we have
```python
>>> a = torch.tensor([[[1, 2], [3, 4]], [[0, 12], [21, 0]]])
>>> a.to_sparse_csr()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expect the same number of specified elements per batch.
```
because the NSEs of the first and second batches differ: 4 and 2, respectively.

This PR implements a strided-to-sparse-CSR/CSC conversion algorithm that supports CSR/CSC batches with different NSE counts. For instance:
```python
>>> a = torch.tensor([[[1, 2], [3, 4]], [[0, 12], [21, 0]]])
>>> b = a.to_sparse_csr()
>>> b
tensor(crow_indices=tensor([[0, 2, 4],
                            [0, 1, 2]]),
       col_indices=tensor([[0, 1, 0, 1],
                           [1, 0, 0, 0]]),
       values=tensor([[ 1,  2,  3,  4],
                      [12, 21,  0,  0]]), size=(2, 2, 2), nnz=4,
       layout=torch.sparse_csr)
>>> b[0]
tensor(crow_indices=tensor([0, 2, 4]),
       col_indices=tensor([0, 1, 0, 1]),
       values=tensor([1, 2, 3, 4]), size=(2, 2), nnz=4,
       layout=torch.sparse_csr)
>>> b[1]
tensor(crow_indices=tensor([0, 1, 2]),
       col_indices=tensor([1, 0]),
       values=tensor([12, 21]), size=(2, 2), nnz=2, layout=torch.sparse_csr)
```
that is, if the NSE of a batch is smaller than the maximum NSE over all batches, the corresponding rows of `col_indices`/`values` are padded with zeros as placeholders. Algorithms operating on batched CSR/CSC tensors must not access these padded regions: they should take the per-batch NSE from the last element of the corresponding `crow_indices` row rather than from `.values().shape[0]`, which holds the maximum NSE over all batches.
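
For example, after this PR a consumer can recover the per-batch NSE from `crow_indices` as sketched below (a minimal illustration; `per_batch_nse`, `cols_i`, and `vals_i` are names chosen here for exposition, not PyTorch API):
```python
import torch

a = torch.tensor([[[1, 2], [3, 4]], [[0, 12], [21, 0]]])
b = a.to_sparse_csr()

# Per-batch NSE: the last entry of each crow_indices row.
per_batch_nse = b.crow_indices()[:, -1]   # tensor([4, 2])

# Maximum NSE over all batches: the trailing dimension of values.
max_nse = b.values().shape[-1]            # 4

# Only the first per_batch_nse[i] entries of batch i are meaningful;
# the remainder is zero padding that must not be interpreted.
for i in range(b.shape[0]):
    n = int(per_batch_nse[i])
    cols_i = b.col_indices()[i, :n]
    vals_i = b.values()[i, :n]
```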

Performance-wise, the strided-to-sparse-CSR/CSC conversion algorithms in master and in this PR are roughly equivalent:
```python
# master branch:
In [2]: a = torch.rand(10, 10, 1000, 1000)

In [3]: a = torch.where(a==0, 0.1, a)  # required for master, optional for the PR

In [4]: %timeit a.to_sparse_csr()
2.25 s ± 9.84 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: a_cuda = a.cuda()

In [6]: %timeit a_cuda.to_sparse_csr()
55.2 ms ± 6.95 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
```python
# this PR
In [2]: a = torch.rand(10, 10, 1000, 1000)

In [3]: a = torch.where(a==0, 0.1, a)  # required for master, optional for the PR

In [4]: %timeit a.to_sparse_csr()
2.13 s ± 7.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: a_cuda = a.cuda()

In [6]: %timeit a_cuda.to_sparse_csr()
54.3 ms ± 20.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
The performance of the PR is only slightly better than that of the master branch.
 
A strided-to-sparse-BSR/BSC conversion with variable NSE support will be implemented as a follow-up.




[ghstack-poisoned]
pearu committed Sep 12, 2022
2 parents 27a47f1 + 7696fd8 commit 16bad3d
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions aten/src/ATen/native/TensorConversions.cpp
@@ -797,7 +797,6 @@ Tensor dense_to_batched_sparse_compressed_nonblock(const Tensor& self, const Lay
 Device index_device = self.device();
 auto n_batch_dim = self.dim() - 2;
 TORCH_INTERNAL_ASSERT(n_batch_dim > 0);
-auto nvalues = self.numel();
 int compressed_dim_size, plain_dim_size;
 std::tie(compressed_dim_size, plain_dim_size) = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(target_layout, "dense_to_batched_sparse_compressed_nonblock",
   [&] { return std::make_tuple(self.size(-2), self.size(-1)); },
@@ -831,17 +830,21 @@ Tensor compressed_indices = batch_compressed_indices.reshape(compressed_indices_size);
 Tensor compressed_indices = batch_compressed_indices.reshape(compressed_indices_size);

 Tensor batch_flat_indices = at::zeros({nbatches, max_nse}, flat_compressed_indices.options());
-at::parallel_for(0, nbatches, 0, [&](int64_t start, int64_t end) {
-  for (const auto i : c10::irange(start, end)) {
-    Tensor tmp = non_zero_mask[i].nonzero().flatten();
-    batch_flat_indices.select(0, i).narrow(0, 0, tmp.numel()).copy_(tmp);
-  }
-});
-if (nvalues > 0) {
-  batch_flat_indices.add_(at::native::arange(0, nvalues, compressed_dim_size * plain_dim_size, index_dtype, kStrided, index_device)
-                          .reshape({nbatches, 1}));
-}
-
+Tensor non_zero_indices = non_zero_mask.flatten().nonzero().flatten();
+Tensor nse_cpu = nse.cpu();
+AT_DISPATCH_INTEGRAL_TYPES(nse_cpu.scalar_type(), "dense_to_batched_sparse_compressed_nonblock",
+  [&]() {
+    scalar_t cnse_im1 = 0, cnse_i = 0;
+    scalar_t* nse_ptr = nse_cpu.data_ptr<scalar_t>();
+    for (const auto i : c10::irange(0, nbatches)) {
+      const auto nse_i = nse_ptr[i];
+      cnse_i += nse_i;
+      batch_flat_indices.select(0, i)
+        .narrow(0, 0, nse_i)
+        .copy_(non_zero_indices.slice(0, cnse_im1, cnse_i, 1));
+      cnse_im1 = cnse_i;
+    }
+  });
 Tensor flat_ordering = batch_flat_indices.flatten();

 // plain_indices and values have the same size because dense
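
For reference, the new code path above computes the non-zero positions of the flattened mask once and then assigns consecutive slices of that index list to the batches according to their cumulative NSE. The Python sketch below illustrates the same idea at a high level; the names (`dense`, `mask`, `nse`, `batch_flat_indices`) are chosen for exposition and do not mirror the ATen internals exactly.
```python
import torch

dense = torch.tensor([[[1., 2.], [3., 4.]], [[0., 12.], [21., 0.]]])
nbatches = dense.shape[0]

mask = dense != 0              # non-zero mask
nse = mask.sum(dim=(-2, -1))   # per-batch NSE, here tensor([4, 2])
max_nse = int(nse.max())

# Non-zero positions of the flattened (batch, row, col) mask, computed once.
flat_nonzero = mask.flatten().nonzero().flatten()

# Distribute consecutive slices of flat_nonzero to their batches; batches with
# fewer non-zeros than max_nse keep trailing zeros as padding.
batch_flat_indices = torch.zeros(nbatches, max_nse, dtype=torch.int64)
offset = 0
for i in range(nbatches):
    n = int(nse[i])
    batch_flat_indices[i, :n] = flat_nonzero[offset:offset + n]
    offset += n
```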
