
Strided to batch BSR/BSC conversion fails when the number of zeros per block varies while the number of blocks per batch is constant #98495

Open
pearu opened this issue Apr 6, 2023 · 3 comments
Assignees
pearu
Labels
module: sparse (Related to torch.sparse), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@pearu
Collaborator

pearu commented Apr 6, 2023

Issue description

As in the title.

Code example

>>> import torch
>>> torch.tensor([[[1, 2]], [[3, 4]]]).to_sparse_bsr((1, 1))
tensor(crow_indices=tensor([[0, 2],
                            [0, 2]]),
       col_indices=tensor([[0, 1],
                           [0, 1]]),
       values=tensor([[[[1]],

                       [[2]]],


                      [[[3]],

                       [[4]]]]), size=(2, 1, 2), nnz=2,
       layout=torch.sparse_bsr)
>>> torch.tensor([[[1, 2]], [[0, 4]]]).to_sparse_bsr((1, 1))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expect the same number of specified elements per batch.
>>> torch.tensor([[[1, 0]], [[0, 4]]]).to_sparse_bsr((1, 1))
tensor(crow_indices=tensor([[0, 1],
                            [0, 1]]),
       col_indices=tensor([[0],
                           [1]]),
       values=tensor([[[[1]]],


                      [[[4]]]]), size=(2, 1, 2), nnz=1,
       layout=torch.sparse_bsr)

Notice that in the failing conversion example, the number of zeros in the first block is 0 and in the second block it is 1.

Apparently, the check logic in

auto nse_per_batch = mask.select(0, 0).sum().expand(mask.size(0));
TORCH_CHECK(
    mask.sum({-2, -1}).equal(nse_per_batch),
    "Expect the same number of specified elements per batch.");

is flawed for BSR and BSC conversion cases.
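
For illustration, here is a small Python sketch of the distinction (not the actual ATen code; the helper names and the 1x2 blocksize are chosen only for this example, to match the scenario in the title where the number of blocks per batch is constant but the number of zeros per block varies). It contrasts the element-level count that the current check compares with a block-level count, which is what NSE means for BSR/BSC:

import torch

def nse_per_batch_elementwise(dense):
    # What the current check effectively compares: nonzero elements per batch.
    return (dense != 0).sum(dim=(-2, -1))

def nse_per_batch_blockwise(dense, blocksize):
    # What BSR/BSC NSE counts: blocks containing at least one nonzero element.
    # Assumes the last two dims are divisible by the blocksize.
    nbatch, nrows, ncols = dense.shape
    br, bc = blocksize
    blocks = dense.reshape(nbatch, nrows // br, br, ncols // bc, bc)
    block_mask = (blocks != 0).any(dim=4).any(dim=2)  # (nbatch, nrows//br, ncols//bc)
    return block_mask.sum(dim=(-2, -1))

t = torch.tensor([[[1, 2]], [[0, 4]]])
print(nse_per_batch_elementwise(t))        # tensor([2, 1]) -> element counts differ
print(nse_per_batch_blockwise(t, (1, 2)))  # tensor([1, 1]) -> one specified block per batch

With a 1x2 blocksize both batches contain exactly one specified block, yet the element-level counts differ because the second block holds a zero, so a check based on the element-level mask rejects an input that is representable as a batched BSR tensor.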

System Info

  • PyTorch version: master

cc @alexsamardzic @nikitaved @cpuhrsch @amjames @bhosmer

@pearu
Collaborator Author

pearu commented Apr 6, 2023

This issue applies to all sparse compressed layouts:

>>> torch.tensor([[[1, 0]], [[3, 4]]]).to_sparse_csr()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expect the same number of specified elements per batch.

For the CSR/CSC cases the failure could be considered the result of an unsupported feature, but for the BSR/BSC cases the bug is real.
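
For comparison, a quick sanity check (plain Python, for illustration only) confirms that in the CSR example the number of specified elements genuinely differs between the two batches, so the rejection there reflects the missing variable-NSE feature rather than a miscount:

import torch

t = torch.tensor([[[1, 0]], [[3, 4]]])
print((t != 0).sum(dim=(-2, -1)))  # tensor([1, 2]) -> NSE genuinely varies per batch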

@pearu changed the title from "Strided to BSR conversion fails when blocks contains different number of zeros" to "Strided to CSR/CSC/BSR/BSC conversion fails when the number of zeros per element/block varies" Apr 6, 2023
@pearu changed the title from "Strided to CSR/CSC/BSR/BSC conversion fails when the number of zeros per element/block varies" to "Strided to CSR/CSC/BSR/BSC conversion fails when the number of zeros per batch varies" Apr 6, 2023
@pearu changed the title from "Strided to CSR/CSC/BSR/BSC conversion fails when the number of zeros per batch varies" to "Strided to batch BSR/BSC conversion fails when the number of zeros per block varies while the number of blocks per batch is constant" Apr 6, 2023
@zou3519 added the module: sparse and triaged labels Apr 10, 2023
@cpuhrsch
Contributor

@pearu - Thanks for discovering this! Can you send a fix for this?

@pearu pearu self-assigned this May 2, 2023
@pearu pearu added this to To do in Sparse tensors via automation May 2, 2023
@pearu
Collaborator Author

pearu commented Jun 26, 2023

This issue is blocked by the lack of support for variable NSE in batches of sparse tensors (the BSR example provided in the description requires variable NSE support). The solution provided in #84843 requires further discussion, see #104193.
