Update on "Support different NSE in batches of CSR and CSC tensors"
This PR enables batched CSR/CSC tensors whose batches may have different NSE counts. For instance, with the current master we have

```python
>>> a = torch.tensor([[[1, 2], [3, 4]], [[0, 12], [21, 0]]])
>>> a.to_sparse_csr()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expect the same number of specified elements per batch.
```

because the NSE counts of the first and second batches are different, 4 and 2, respectively.

This PR implements a strided-to-sparse-CSR/CSC conversion algorithm that supports CSR/CSC batches with different NSE counts. For instance:

```python
>>> a = torch.tensor([[[1, 2], [3, 4]], [[0, 12], [21, 0]]])
>>> b = a.to_sparse_csr()
>>> b
tensor(crow_indices=tensor([[0, 2, 4],
                            [0, 1, 2]]),
       col_indices=tensor([[0, 1, 0, 1],
                           [1, 0, 0, 0]]),
       values=tensor([[ 1,  2,  3,  4],
                      [12, 21,  0,  0]]),
       size=(2, 2, 2), nnz=4, layout=torch.sparse_csr)
>>> b[0]
tensor(crow_indices=tensor([0, 2, 4]),
       col_indices=tensor([0, 1, 0, 1]),
       values=tensor([1, 2, 3, 4]),
       size=(2, 2), nnz=4, layout=torch.sparse_csr)
>>> b[1]
tensor(crow_indices=tensor([0, 1, 2]),
       col_indices=tensor([1, 0]),
       values=tensor([12, 21]),
       size=(2, 2), nnz=2, layout=torch.sparse_csr)
```

That is, if the NSE of a batch is smaller than the maximum NSE over all batches, the corresponding rows in `col_indices`/`values` are padded with zeros as placeholders. Algorithms on batched CSR/CSC tensors must not access the padded parts of these tensors: they should use the last element of the corresponding `crow_indices` row as the NSE value, rather than `.values().shape[0]`, which holds the maximum NSE over all batches.
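To make the padding scheme concrete, here is a minimal pure-Python sketch of a strided-to-batched-CSR conversion following the layout described above. The helper names (`dense_to_csr`, `dense_to_batched_csr`) are illustrative only, not part of the PyTorch API, and the sketch works on nested lists rather than tensors:

```python
# Illustrative sketch of the variable-NSE batched CSR layout; helper
# names are hypothetical, not PyTorch API.

def dense_to_csr(matrix):
    """Convert one dense 2-D matrix (list of rows) to CSR arrays."""
    crow, col, vals = [0], [], []
    for row in matrix:
        for j, v in enumerate(row):
            if v != 0:
                col.append(j)
                vals.append(v)
        crow.append(len(vals))
    return crow, col, vals

def dense_to_batched_csr(batch):
    """Convert a batch of dense matrices to batched CSR arrays,
    padding col_indices/values with zeros up to the maximum NSE
    over all batches."""
    per_batch = [dense_to_csr(m) for m in batch]
    max_nse = max(crow[-1] for crow, _, _ in per_batch)
    crow_indices, col_indices, values = [], [], []
    for crow, col, vals in per_batch:
        pad = max_nse - len(vals)       # zero-padding placeholders
        crow_indices.append(crow)
        col_indices.append(col + [0] * pad)
        values.append(vals + [0] * pad)
    return crow_indices, col_indices, values

batch = [[[1, 2], [3, 4]], [[0, 12], [21, 0]]]
crow, col, vals = dense_to_batched_csr(batch)
# The true NSE of batch i is crow[i][-1], not len(vals[i]):
assert crow[0][-1] == 4 and crow[1][-1] == 2
assert len(vals[1]) == 4  # storage is padded to the maximum NSE
```

Running this on the example tensor from above reproduces the padded `crow_indices`/`col_indices`/`values` shown in the PR output, and the final assertions illustrate why consumers must read the per-batch NSE from the last element of each `crow_indices` row.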
Performance-wise, the strided-to-sparse-CSR/CSC conversion algorithms in master and in this PR are roughly equivalent:

```python
# master branch:
In [2]: a = torch.rand(10, 10, 1000, 1000)

In [3]: a = torch.where(a==0, 0.1, a)  # required for master, optional for the PR

In [4]: %timeit a.to_sparse_csr()
2.25 s ± 9.84 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: a_cuda = a.cuda()

In [6]: %timeit a_cuda.to_sparse_csr()
55.2 ms ± 6.95 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

```python
# this PR
In [2]: a = torch.rand(10, 10, 1000, 1000)

In [3]: a = torch.where(a==0, 0.1, a)  # required for master, optional for the PR

In [4]: %timeit a.to_sparse_csr()
2.13 s ± 7.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: a_cuda = a.cuda()

In [6]: %timeit a_cuda.to_sparse_csr()
54.3 ms ± 20.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

The performance of the PR is only slightly better than that of the master branch. A strided-to-sparse-BSR/BSC conversion with variable NSE support will be implemented as a follow-up.

[ghstack-poisoned]