Update on "Support different NSE in batches of CSR and CSC tensors"
This PR enables batched CSR/CSC tensors whose batches may have different NSE (number of specified elements) counts. For instance, with the current master we have

```python
>>> a = torch.tensor([[[1, 2], [3, 4]], [[0, 12], [21, 0]]])
>>> a.to_sparse_csr()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expect the same number of specified elements per batch.
```

because the NSE of the first and second batches differ: 4 and 2, respectively.

This PR implements a strided-to-sparse-CSR/CSC conversion algorithm that supports CSR/CSC batches with different NSE counts. For instance:

```python
>>> a = torch.tensor([[[1, 2], [3, 4]], [[0, 12], [21, 0]]])
>>> b = a.to_sparse_csr()
>>> b
tensor(crow_indices=tensor([[0, 2, 4],
                            [0, 1, 2]]),
       col_indices=tensor([[0, 1, 0, 1],
                           [1, 0, 0, 0]]),
       values=tensor([[ 1,  2,  3,  4],
                      [12, 21,  0,  0]]), size=(2, 2, 2), nnz=4,
       layout=torch.sparse_csr)
>>> b[0]
tensor(crow_indices=tensor([0, 2, 4]),
       col_indices=tensor([0, 1, 0, 1]),
       values=tensor([1, 2, 3, 4]), size=(2, 2), nnz=4,
       layout=torch.sparse_csr)
>>> b[1]
tensor(crow_indices=tensor([0, 1, 2]),
       col_indices=tensor([1, 0]),
       values=tensor([12, 21]), size=(2, 2), nnz=2,
       layout=torch.sparse_csr)
```

That is, if the NSE of a batch is smaller than the maximum NSE over all batches, the corresponding rows of `col_indices`/`values` are padded with zeros as placeholders. Algorithms on batched CSR/CSC tensors must not access the padded parts of these tensors: they should use the last element of the corresponding `crow_indices` row as the NSE value, rather than the size of the NSE dimension of `.values()` (`.values().shape[-1]` in the example above), which holds the maximum NSE over all batches.
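The padded layout and the NSE rule can be illustrated with a minimal pure-Python sketch. It is not the PR's implementation: nested lists stand in for tensors, and `dense_to_batched_csr` and `batch_components` are hypothetical helper names. The sketch converts each batch to CSR, pads `col_indices`/`values` with zeros up to the maximum NSE, and recovers a batch's unpadded components by reading its NSE from the last element of its `crow_indices` row.

```python
from typing import List, Tuple

Matrix = List[List[int]]
Rows = List[List[int]]


def dense_to_batched_csr(batch: List[Matrix]) -> Tuple[Rows, Rows, Rows]:
    """Hypothetical helper: convert a batch of dense 2-D matrices to padded
    batched CSR components (crow_indices, col_indices, values)."""
    crow_indices, col_indices, values = [], [], []
    for matrix in batch:
        crow, cols, vals = [0], [], []
        for row in matrix:
            for j, x in enumerate(row):
                if x != 0:  # only nonzeros become specified elements
                    cols.append(j)
                    vals.append(x)
            crow.append(len(vals))  # running count of specified elements
        crow_indices.append(crow)
        col_indices.append(cols)
        values.append(vals)
    # Pad every batch up to the maximum NSE over all batches.
    max_nse = max((len(v) for v in values), default=0)
    for cols, vals in zip(col_indices, values):
        pad = max_nse - len(vals)
        cols.extend([0] * pad)  # placeholder column indices
        vals.extend([0] * pad)  # placeholder values
    return crow_indices, col_indices, values


def batch_components(crow_indices: Rows, col_indices: Rows, values: Rows, b: int):
    """Hypothetical helper: return the unpadded CSR components of batch `b`.

    The NSE is taken from crow_indices[b][-1], never from the padded row length.
    """
    nse = crow_indices[b][-1]
    return crow_indices[b], col_indices[b][:nse], values[b][:nse]


a = [[[1, 2], [3, 4]], [[0, 12], [21, 0]]]
crow, col, val = dense_to_batched_csr(a)
print(crow)  # [[0, 2, 4], [0, 1, 2]]
print(col)   # [[0, 1, 0, 1], [1, 0, 0, 0]]
print(val)   # [[1, 2, 3, 4], [12, 21, 0, 0]]
print(batch_components(crow, col, val, 1))  # ([0, 1, 2], [1, 0], [12, 21])
```

The printed components match the `crow_indices`, `col_indices`, and `values` of `b`, `b[0]`, and `b[1]` in the example above, including the zero placeholders in the second batch.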
Performance-wise, the strided-to-sparse-CSR/CSC conversion algorithms in master and in this PR are roughly equivalent:

```python
# master branch:
In [2]: a = torch.rand(10, 10, 1000, 1000)
In [3]: a = torch.where(a==0, 0.1, a)  # required for master, optional for the PR
In [4]: %timeit a.to_sparse_csr()
2.25 s ± 9.84 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [5]: a_cuda = a.cuda()
In [6]: %timeit a_cuda.to_sparse_csr()
55.2 ms ± 6.95 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

```python
# this PR
In [2]: a = torch.rand(10, 10, 1000, 1000)
In [3]: a = torch.where(a==0, 0.1, a)  # required for master, optional for the PR
In [4]: %timeit a.to_sparse_csr()
2.12 s ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [5]: a_cuda = a.cuda()
In [6]: %timeit a_cuda.to_sparse_csr(); torch.cuda.synchronize()
47.2 ms ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

On CUDA tensors, the mean time per `to_sparse_csr()` call dropped by about 15% with this PR (from 55.2 ms to 47.2 ms). A strided-to-sparse-BSR/BSC conversion with variable NSE support will be implemented as a follow-up.

[ghstack-poisoned]
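The CUDA improvement can be checked from the reported mean timings alone (55.2 ms on master, 47.2 ms with this PR; CPU: 2.25 s and 2.12 s). The small sketch below uses a hypothetical `compare` helper and takes no new measurements; it only separates the fraction of time saved from the speedup factor, which are easy to conflate.

```python
from typing import Tuple


def compare(before: float, after: float) -> Tuple[float, float]:
    """Hypothetical helper: return (fraction of time saved, speedup factor)
    for two mean timings given in the same unit."""
    return (before - after) / before, before / after


# Mean timings copied from the %timeit output above.
cuda_saved, cuda_speedup = compare(55.2, 47.2)  # milliseconds
cpu_saved, cpu_speedup = compare(2.25, 2.12)    # seconds

print(f"CUDA: {cuda_saved:.1%} less time, {cuda_speedup:.2f}x speedup")  # CUDA: 14.5% less time, 1.17x speedup
print(f"CPU: {cpu_saved:.1%} less time, {cpu_speedup:.2f}x speedup")     # CPU: 5.8% less time, 1.06x speedup
```

A roughly 15% reduction in time corresponds to about a 1.17x speedup, while the CPU difference is small enough to call the two versions roughly equivalent there.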
Showing 562 changed files with 16,463 additions and 14,294 deletions.
`@@ -1 +1 @@`: `ffd056dc1510bdfecafb689ed87601055694f3e6` → `41c44bc1d080d6cf063419a4166732b983b84eef`
`@@ -1 +1 @@`: `a67cc87a33a3f713aebf5299bdeb2672c98e0bc5` → `a4f53308b2d0f1aa9191686e326f45c26053f686`
`@@ -1 +1 @@`: `e0dcc3171c8024ab288551d105fba24fbfae7332` → `307af4313d2b0b0236618ef837959a41068cc272`