Conversion from strided to batched sparse compressed tensor with a non-constant number of zeros in batches fails #104193

Open
pearu opened this issue Jun 26, 2023 · 2 comments

pearu commented Jun 26, 2023

Issue description

As in the title.

The issue is created to discuss various approaches to supporting the strided-to-sparse-compressed conversion for the cases where the number of zeros in different batches is not equal.

Code example

Consider the following batched tensor of 2-by-2 tensors:

x = [[[0, 1],  # batch 1
      [2, 0]],
     [[3, 0],  # batch 2
      [0, 4]],
    ]

that can be represented as a batched CSR tensor:

>>> torch.tensor(x).to_sparse_csr()
tensor(crow_indices=tensor([[0, 1, 2],
                            [0, 1, 2]]),
       col_indices=tensor([[1, 0],
                           [0, 1]]),
       values=tensor([[1, 2],
                      [3, 4]]), size=(2, 2, 2), nnz=2, layout=torch.sparse_csr)

because both batches have an equal number of zeros: 2 each.

Next, consider a batched tensor with an unequal number of zeros in batches:

y = [[[0, 1],  # batch 1
      [2, 9]],
     [[3, 0],  # batch 2
      [0, 4]],
    ]

that currently cannot be represented as a batched CSR tensor:

>>> torch.tensor(y).to_sparse_csr()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expect the same number of specified elements per batch.

because the number of zeros in batches is different: 1 and 2, respectively.
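For reference, the per-batch zero counts can be read off the strided tensor directly (a small sketch, assuming two sparse dimensions and no dense dimensions):

>>> (torch.tensor(y) == 0).sum(dim=(-2, -1))
tensor([1, 2])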

Discussion

The following approaches exist to create a batched CSR tensor from batches having unequal numbers of zeros.

Approach 1: allow materialization of certain zeros

Notice that in the conversion of a strided tensor to a CSR tensor, non-zero elements and specified elements are treated as synonyms. If we relax this condition and allow certain zero elements to become specified elements in the CSR representation, the example tensor y defined above can be represented as a batched CSR tensor. In fact, many such representations exist, for example:

>>> torch.sparse_csr_tensor([[0, 1, 3], [0, 1, 3]], [[1, 0, 1], [0, 0, 1]], [[1, 2, 9], [3, 0, 4]]).to_dense()
tensor([[[0, 1],
         [2, 9]],
        [[3, 0],
         [0, 4]]])
>>> torch.sparse_csr_tensor([[0, 1, 3], [0, 2, 3]], [[1, 0, 1], [0, 1, 1]], [[1, 2, 9], [3, 0, 4]]).to_dense()
tensor([[[0, 1],
         [2, 9]],
        [[3, 0],
         [0, 4]]])
>>> torch.sparse_csr_tensor([[0, 2, 4], [0, 2, 4]], [[0, 1, 0, 1], [0, 1, 0, 1]], [[0, 1, 2, 9], [3, 0, 0, 4]]).to_dense()
tensor([[[0, 1],
         [2, 9]],
        [[3, 0],
         [0, 4]]])

that differ in the choice of materialized zeros.
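If a mask of the elements to be specified is provided, the ambiguity disappears. As a sketch (assuming y is first materialized as a strided tensor; the mask below keeps all non-zeros and additionally marks the zero at position (0, 1) of batch 2 as specified), the second representation above is recovered via sparse_mask:

>>> mask = torch.tensor([[[False, True], [True, True]], [[True, True], [False, True]]])
>>> torch.tensor(y).sparse_mask(mask.to_sparse_csr())
tensor(crow_indices=tensor([[0, 1, 3],
                            [0, 2, 3]]),
       col_indices=tensor([[1, 0, 1],
                           [0, 1, 1]]),
       values=tensor([[1, 2, 9],
                      [3, 0, 4]]), size=(2, 2, 2), nnz=3,
       layout=torch.sparse_csr)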

Pros:

  • solves the issue using the existing sparse compressed tensor implementation

Cons:

  • requires introducing a complex and non-parallelizable strided->sparse compressed conversion algorithm that materializes zeros (suboptimal storage) and is ambiguous about which zeros to materialize (providing a mask would resolve the ambiguity)
  • batches with smaller NSE have the same memory usage as the batches with the largest NSE
  • the resulting per-batch NSE can be larger than the number of non-zeros in the densest batch (the batch with the fewest zeros), as in the last example above

Approach 2: allow a variable number of specified elements in batches

A prototype of this approach is implemented at #84843

The example tensor y defined above can be represented as a batched CSR tensor uniquely:

>>> z = torch.tensor(y).to_sparse_csr()
>>> z
tensor(crow_indices=tensor([[0, 1, 3],
                            [0, 1, 2]]),
       col_indices=tensor([[1, 0, 1],
                           [0, 1, 0]]),
       values=tensor([[1, 2, 9],
                      [3, 4, 0]]), size=(2, 2, 2), nnz=3,
       layout=torch.sparse_csr)

where each batch has the expected NSE count:

>>> z[0]
tensor(crow_indices=tensor([0, 1, 3]),
       col_indices=tensor([1, 0, 1]),
       values=tensor([1, 2, 9]), size=(2, 2), nnz=3, layout=torch.sparse_csr)
>>> z[1]
tensor(crow_indices=tensor([0, 1, 2]),
       col_indices=tensor([0, 1]),
       values=tensor([3, 4]), size=(2, 2), nnz=2, layout=torch.sparse_csr)

Pros:

  • solves the issue without explicitly materializing zeros
  • the batched sparse compressed tensor representation is unique
  • the strided->sparse compressed conversion algorithm is simple and parallelizable
  • the maximum per-batch NSE equals the number of non-zeros in the densest batch (the batch with the fewest zeros)
  • the performance of to_sparse_csr() on CUDA tensors increased by 15%

Cons:

  • requires relaxing the sparse compressed invariant: compressed_index[..., -1] == nnz becomes compressed_index[..., -1] <= nnz (see the sketch after this list)
  • batches with smaller NSE have the same memory usage as the batches with the largest NSE (the optimal storage would require ragged tensor support)
  • the indices and values of unused elements in batches with smaller NSE must be initialized so that third-party libraries are not confused by a batched tensor with variable NSE (libraries that treat batches as independent should be unaffected)
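Under the relaxed invariant, the effective NSE of each batch can be read from the last entry of its compressed indices. A sketch (only meaningful with the prototype from #84843 that produced z above):

>>> z.crow_indices()[:, -1]  # per-batch NSE: 3 and 2
tensor([3, 2])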

System Info

  • PyTorch version: main

cc @alexsamardzic @nikitaved @cpuhrsch @amjames @bhosmer

candyflower2005 commented Jul 19, 2023

@pearu Hey, thank you for raising this issue. Do you know if the PR will be merged anytime soon? The requirement of having the same mask for each sample is also a blocker for me.

pearu commented Aug 26, 2023

A Python implementation of Approach 1 is as follows:

def dense_to_mask(dense, sparse_dim=2, dense_dim=0):
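    """Return a boolean mask with the shape of `dense` that is True at the
    elements to be specified in a sparse compressed representation.  In
    batches with more zeros than the minimum, some zeros are materialized
    (marked True) so that every batch has the same number of specified
    elements.
    """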
    batch_dim = dense.ndim - sparse_dim - dense_dim

    # initial negative mask
    mask = dense == 0

    if dense_dim > 0:
        # eliminate dense dimensions
        for i in range(dense_dim):
            mask = mask.all(-1)

    if batch_dim > 1:
        # reduce multidimensional batches to 1D batch
        mask = mask.flatten(0, batch_dim - 1)
    elif batch_dim == 0:
        # add extra batch dimension
        mask = mask.unsqueeze(0)
    nof_zeros_per_batch = mask.sum(dim=tuple(range(1, sparse_dim + 1)))
    nof_zeros_per_batch_min = nof_zeros_per_batch.min()
    nof_zeros_per_batch_max = nof_zeros_per_batch.max()

    if nof_zeros_per_batch_min != nof_zeros_per_batch_max:
        # materialize zeros in smaller batches to meet the
        # equal-nnz-per-batch requirement
        for i in range(len(nof_zeros_per_batch)):
            n = nof_zeros_per_batch[i] - nof_zeros_per_batch_min
            if n > 0:
                submask = mask[i].flatten()
                materialize_zero_indices = submask.nonzero()[:n]
                submask[materialize_zero_indices] = False
                # todo: skip copy when flatten above produces a view
                mask[i].copy_(submask.unflatten(0, mask[i].shape))

    if batch_dim > 1:
        # restore multidimensional batches
        mask = mask.unflatten(0, dense.shape[:batch_dim])
    elif batch_dim == 0:
        # undo adding extra batch dimension
        mask = mask.squeeze(0)

    if dense_dim > 0:
        # restore dense dimension
        for i in range(dense_dim):
            mask = mask.unsqueeze(-1)
        mask = mask.expand(dense.shape)
    return ~mask

(tested on all sparse compressed layouts). Using the example in the description, with y materialized as a strided tensor via torch.tensor(y), we have:

In [5]: y.sparse_mask(dense_to_mask(y).to_sparse_csr())
Out[5]: 
tensor(crow_indices=tensor([[0, 1, 3],
                            [0, 2, 3]]),
       col_indices=tensor([[1, 0, 1],
                           [0, 1, 1]]),
       values=tensor([[1, 2, 9],
                      [3, 0, 4]]), size=(2, 2, 2), nnz=3,
       layout=torch.sparse_csr)
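
As a quick sanity check under the same assumptions, converting back to a strided tensor recovers y:

In [6]: y.sparse_mask(dense_to_mask(y).to_sparse_csr()).to_dense()
Out[6]: 
tensor([[[0, 1],
         [2, 9]],
        [[3, 0],
         [0, 4]]])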
