Support different NSE in batches of CSR and CSC tensors #84843

Draft: wants to merge 5 commits into base: gh/pearu/56/base
Update on "Support different NSE in batches of CSR and CSC tensors"
This PR enables batched CSR/CSC tensors whose batches have different numbers of specified elements (NSE).

For instance, on the current master branch we have
```python
>>> a = torch.tensor([[[1, 2], [3, 4]], [[0, 12], [21, 0]]])
>>> a.to_sparse_csr()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expect the same number of specified elements per batch.
```
because the first and second batches have different NSE counts: 4 and 2, respectively.
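To see where the error comes from, here is a minimal pure-Python sketch (it runs without PyTorch) that counts the specified, i.e. nonzero, elements of each batch of the example tensor. In PyTorch the same per-batch count is `(a != 0).sum((-2, -1))`. The helper name `nse_per_batch` is hypothetical.

```python
def nse_per_batch(batches):
    """Count the nonzero entries of each 2-D batch in a nested-list tensor."""
    return [sum(1 for row in matrix for x in row if x != 0) for matrix in batches]


a = [[[1, 2], [3, 4]], [[0, 12], [21, 0]]]
nse = nse_per_batch(a)
print(nse)  # [4, 2]

# On master, to_sparse_csr() requires every batch to have the same NSE,
# so a tensor with unequal per-batch counts is rejected.
if max(nse) != min(nse):
    print("batches have different NSE")
```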

This PR implements a strided-to-sparse-CSR/CSC conversion algorithm that supports CSR/CSC batches with different NSE counts. For instance:
```python
>>> a = torch.tensor([[[1, 2], [3, 4]], [[0, 12], [21, 0]]])
>>> b = a.to_sparse_csr()
>>> b
tensor(crow_indices=tensor([[0, 2, 4],
                            [0, 1, 2]]),
       col_indices=tensor([[0, 1, 0, 1],
                           [1, 0, 0, 0]]),
       values=tensor([[ 1,  2,  3,  4],
                      [12, 21,  0,  0]]), size=(2, 2, 2), nnz=4,
       layout=torch.sparse_csr)
>>> b[0]
tensor(crow_indices=tensor([0, 2, 4]),
       col_indices=tensor([0, 1, 0, 1]),
       values=tensor([1, 2, 3, 4]), size=(2, 2), nnz=4,
       layout=torch.sparse_csr)
>>> b[1]
tensor(crow_indices=tensor([0, 1, 2]),
       col_indices=tensor([1, 0]),
       values=tensor([12, 21]), size=(2, 2), nnz=2, layout=torch.sparse_csr)
```
that is, if the NSE of a batch is smaller than the maximum NSE over all batches, the corresponding rows of `col_indices` and `values` are padded with zeros as placeholders. Algorithms on batched CSR/CSC tensors must not access these padded entries: they should use the last element of the corresponding `crow_indices` row as the NSE of that batch, rather than the size of the NSE dimension of `.values()` (`.values().shape[-1]` in the example above), which holds the maximum NSE over all batches.
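The padding scheme can be illustrated with a pure-Python sketch that converts a list of equally shaped dense 2-D matrices into batched CSR components and pads `col_indices` and `values` with zeros up to the largest NSE. This is not PyTorch's implementation; the function names `to_batched_csr` and `batch_csr` are hypothetical. The second helper follows the access rule: a batch's NSE is read from the last entry of its `crow_indices` row, so the padded tail is never touched.

```python
def to_batched_csr(batches):
    """Convert dense 2-D matrices (nested lists) to zero-padded batched CSR."""
    crow_indices, col_indices, values = [], [], []
    for matrix in batches:
        crow, cols, vals = [0], [], []
        for row in matrix:
            for j, x in enumerate(row):
                if x != 0:
                    cols.append(j)
                    vals.append(x)
            crow.append(len(cols))  # running count of specified elements
        crow_indices.append(crow)
        col_indices.append(cols)
        values.append(vals)
    max_nse = max((len(cols) for cols in col_indices), default=0)
    # Pad batches with fewer specified elements using zero placeholders.
    for cols, vals in zip(col_indices, values):
        pad = max_nse - len(cols)
        cols.extend([0] * pad)
        vals.extend([0] * pad)
    return crow_indices, col_indices, values


def batch_csr(crow_indices, col_indices, values, i):
    """Return the unpadded CSR components of batch i."""
    nse = crow_indices[i][-1]  # true NSE of batch i, not len(values[i])
    return crow_indices[i], col_indices[i][:nse], values[i][:nse]


a = [[[1, 2], [3, 4]], [[0, 12], [21, 0]]]
crow, col, val = to_batched_csr(a)
print(crow)  # [[0, 2, 4], [0, 1, 2]]
print(col)   # [[0, 1, 0, 1], [1, 0, 0, 0]]
print(val)   # [[1, 2, 3, 4], [12, 21, 0, 0]]
print(batch_csr(crow, col, val, 1))  # ([0, 1, 2], [1, 0], [12, 21])
```

The padded components match the `crow_indices`, `col_indices`, and `values` shown for `b` above, and slicing with `crow_indices[i][-1]` recovers the same per-batch components as `b[0]` and `b[1]`.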

Performance-wise, the strided-to-sparse-CSR/CSC conversion algorithms in master and in this PR are roughly equivalent:
```python
# master branch:
In [2]: a = torch.rand(10, 10, 1000, 1000)

In [3]: a = torch.where(a==0, 0.1, a)  # required for master, optional for the PR

In [4]: %timeit a.to_sparse_csr()
2.25 s ± 9.84 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: a_cuda = a.cuda()

In [6]: %timeit a_cuda.to_sparse_csr()
55.2 ms ± 6.95 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
```python
# this PR
In [2]: a = torch.rand(10, 10, 1000, 1000)

In [3]: a = torch.where(a==0, 0.1, a)  # required for master, optional for the PR

In [4]: %timeit a.to_sparse_csr()
2.12 s ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: a_cuda = a.cuda()

In [6]: %timeit a_cuda.to_sparse_csr(); torch.cuda.synchronize()
47.2 ms ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
With this PR, `to_sparse_csr()` on CUDA tensors takes about 15% less time (47.2 ms vs. 55.2 ms per call).
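The relative gains can be recomputed from the mean timings reported in the two sessions above. A small sketch (the helper `compare` is hypothetical) that returns both the fraction of time saved and the speedup factor:

```python
def compare(master, pr):
    """Return (fraction of time saved, speedup factor) for two mean timings."""
    return (master - pr) / master, master / pr


# Mean time per call from the benchmarks: (master, this PR).
timings = {
    "CPU": (2.25, 2.12),   # seconds
    "CUDA": (55.2, 47.2),  # milliseconds
}

for device, (master, pr) in timings.items():
    saved, speedup = compare(master, pr)
    print(f"{device}: {saved:.1%} less time, {speedup:.2f}x speedup")
```

On these numbers the CUDA path saves about 14.5% of the time per call (a speedup of about 1.17x), while the CPU path saves about 5.8%.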
 
A strided-to-sparse-BSR/BSC conversion with variable NSE support will be implemented as a follow-up.




[ghstack-poisoned]
pearu committed Sep 19, 2022
commit 18f4bfcc27a42f29e8d2940feba800ddf440de7b
9 changes: 0 additions & 9 deletions test/test_mm.py
```diff
@@ -206,15 +206,6 @@ def sample_inputs_generator():
     other = sample_input.args[0]
     nse = (other != 0).sum((-2, -1))
     for other_layout in layouts:
-        if nse.max() != nse.min() and other_layout in {torch.sparse_csr, torch.sparse_csc,
-                                                       torch.sparse_bsr, torch.sparse_bsc}:
-            # TODO: sparse compressed tensors
-            # require equal NSE in batches (see PR
-            # 84834 that removes this
-            # restriction). After PR 84834 (or
-            # equivalent) lands, remove this
-            # if-block
-            continue
         if other.layout != other_layout:
             yield SampleInput(sample_input.input.clone(),
                               args=(torch.sparse._to_layout(other, other_layout),),
```
27 changes: 11 additions & 16 deletions torch/testing/_internal/common_methods_invocations.py
```diff
@@ -1046,22 +1046,17 @@ def template_reference_inputs_bmm_sparse(self, input_layout, device, dtype, requ
                           args=(b_sparse.detach().requires_grad_(requires_grad),),
                           kwargs=sample_input.kwargs)
 
-    if input_layout not in {torch.sparse_csr, torch.sparse_csc,
-                            torch.sparse_bsr, torch.sparse_bsc}:
-        # TODO: sparse compressed tensors require equal NSE in batches
-        # (see PR 84834 that removes this restriction). After PR 84834
-        # (or equivalent) lands, enable this if-block for all layouts.
-        a = make_tensor((2 * S, S + 2, S + 3), dtype=dtype, device=device, requires_grad=requires_grad, exclude_zero=True)
-        a[0].fill_(0)
-        a[2].fill_(0)
-        a[3, 1].fill_(0)
-        a[4, :, 1].fill_(0)
-        b = make_tensor((2 * S, S + 3, S + 1), dtype=dtype, device=device, requires_grad=requires_grad, exclude_zero=True)
-        b[1].fill_(0)
-        b[2].fill_(0)
-        b[3, 2].fill_(0)
-        b[4, :, 2].fill_(0)
-        yield SampleInput(torch.sparse._to_layout(a, input_layout), args=(b,))
+    a = make_tensor((2 * S, S + 2, S + 3), dtype=dtype, device=device, requires_grad=requires_grad, exclude_zero=True)
+    a[0].fill_(0)
+    a[2].fill_(0)
+    a[3, 1].fill_(0)
+    a[4, :, 1].fill_(0)
+    b = make_tensor((2 * S, S + 3, S + 1), dtype=dtype, device=device, requires_grad=requires_grad, exclude_zero=True)
+    b[1].fill_(0)
+    b[2].fill_(0)
+    b[3, 2].fill_(0)
+    b[4, :, 2].fill_(0)
+    yield SampleInput(torch.sparse._to_layout(a, input_layout), args=(b,))
 
 
 def reference_inputs_sparse_coo_bmm(self, device, dtype, requires_grad, **kwargs):
```
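The bmm sample input gives the batches of `a` very different NSE counts, including batches with no specified elements at all. This is why the block was previously restricted to non-compressed layouts. The pure-Python sketch below reproduces the zeroing pattern under two assumptions: `S = 5`, and a filler value of 1 stands in for the random nonzero entries that `make_tensor(..., exclude_zero=True)` produces.

```python
S = 5  # assumed value of the size constant used by the sample input

# Stand-in for make_tensor((2 * S, S + 2, S + 3), exclude_zero=True):
# every entry is nonzero (here simply 1).
a = [[[1] * (S + 3) for _ in range(S + 2)] for _ in range(2 * S)]

# Apply the same zeroing as the sample input.
a[0] = [[0] * (S + 3) for _ in range(S + 2)]  # a[0].fill_(0): whole batch
a[2] = [[0] * (S + 3) for _ in range(S + 2)]  # a[2].fill_(0): whole batch
a[3][1] = [0] * (S + 3)                       # a[3, 1].fill_(0): one row
for row in a[4]:                              # a[4, :, 1].fill_(0): one column
    row[1] = 0

# NSE of each batch: count of nonzero entries.
nse = [sum(1 for row in m for x in row if x != 0) for m in a]
print(nse)  # [0, 56, 0, 48, 49, 56, 56, 56, 56, 56]
```

With these assumptions, batches 0 and 2 are empty, so their `crow_indices` rows are all zeros and their `col_indices`/`values` rows consist entirely of padding.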