Support different NSE in batches of CSR and CSC tensors #84843

Draft
wants to merge 5 commits into base: gh/pearu/56/base
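
For context, a minimal usage sketch of what this PR is meant to enable, assuming its changes are applied (before them, the conversion below raises "Expect the same number of specified elements per batch."):

import torch

# Two batch items with a different number of specified elements (NSE).
dense = torch.tensor([[[1., 0., 0.],
                       [0., 2., 0.]],
                      [[3., 4., 5.],
                       [0., 6., 7.]]])

csr = dense.to_sparse_csr()       # no longer requires equal NSE per batch
print(csr.crow_indices())         # per-batch compressed row indices
print(csr.values().shape)         # values are padded up to the maximum NSE
print(csr.select(0, 0).values())  # selecting a batch item drops the padding
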
103 changes: 85 additions & 18 deletions aten/src/ATen/native/TensorConversions.cpp
@@ -792,17 +792,94 @@ std::pair<Tensor, Tensor> _not_zero_mask_to_col_row_indices(
return std::pair<Tensor, Tensor>(col_indices, row_indices);
}

Tensor dense_to_batched_sparse_compressed_nonblock(const Tensor& self, const Layout& target_layout) {
ScalarType index_dtype = at::kLong;
Device index_device = self.device();
auto n_batch_dim = self.dim() - 2;
TORCH_INTERNAL_ASSERT(n_batch_dim > 0);
int compressed_dim_size, plain_dim_size;
std::tie(compressed_dim_size, plain_dim_size) = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(target_layout, "dense_to_batched_sparse_compressed_nonblock",
[&] { return std::make_tuple(self.size(-2), self.size(-1)); },
[&] { return std::make_tuple(self.size(-1), self.size(-2)); });
auto batchsize = self.sizes().slice(0, n_batch_dim);
auto nbatches = size_from_dim_(0, batchsize);
TORCH_CHECK(nbatches > 0,
"to_sparse_",
sparse_csr::layoutToString(target_layout, false, true),
": Expected product of batch dimensions to be non-zero.");
Tensor input = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(target_layout, "dense_to_batched_sparse_compressed_nonblock",
[&] { return self; },
[&] { return self.transpose(-2, -1); });
Tensor non_zero_mask = (input != 0).flatten(0, n_batch_dim-1).flatten(-2, -1);
Tensor nse = non_zero_mask.sum(1);
int64_t max_nse = AT_DISPATCH_INTEGRAL_TYPES(nse.scalar_type(), "dense_to_batched_sparse_compressed_nonblock",
[&]() -> int64_t { return nse.max().item<scalar_t>(); });
Tensor flat_uncompressed_indices = at::native::arange(nbatches * compressed_dim_size, index_dtype, kStrided, index_device)
.repeat_interleave(plain_dim_size)
.flatten()
.masked_select(non_zero_mask.flatten());
Tensor flat_compressed_indices = at::_convert_indices_from_coo_to_csr(flat_uncompressed_indices, nbatches * compressed_dim_size, false /*out_int32*/);
Tensor batch_compressed_indices = at::zeros({nbatches, compressed_dim_size + 1}, flat_compressed_indices.options());
if (compressed_dim_size > 0) {
batch_compressed_indices.narrow(1, 1, compressed_dim_size)
.copy_(flat_compressed_indices.slice(0, 1, c10::nullopt, 1).reshape({nbatches, compressed_dim_size})
- flat_compressed_indices.slice(0, 0, -1, compressed_dim_size).reshape({nbatches, 1}));
}
auto compressed_indices_size = DimVector(batchsize);
compressed_indices_size.push_back(compressed_dim_size + 1);
Tensor compressed_indices = batch_compressed_indices.reshape(compressed_indices_size);

Tensor batch_flat_indices = at::zeros({nbatches, max_nse}, flat_compressed_indices.options());
Tensor non_zero_indices = non_zero_mask.flatten().nonzero().flatten();
Tensor nse_cpu = nse.cpu();
AT_DISPATCH_INTEGRAL_TYPES(nse_cpu.scalar_type(), "dense_to_batched_sparse_compressed_nonblock",
[&]() {
scalar_t cnse_im1 = 0, cnse_i = 0;
scalar_t* nse_ptr = nse_cpu.data_ptr<scalar_t>();
for (const auto i : c10::irange(0, nbatches)) {
const auto nse_i = nse_ptr[i];
cnse_i += nse_i;
batch_flat_indices.select(0, i)
.narrow(0, 0, nse_i)
.copy_(non_zero_indices.slice(0, cnse_im1, cnse_i, 1));
cnse_im1 = cnse_i;
}
});
Tensor flat_ordering = batch_flat_indices.flatten();

// plain_indices and values have the same size because dense
// dimensionality is 0.
auto values_size = DimVector(batchsize);
values_size.push_back(max_nse);

Tensor plain_indices = at::native::arange(0, plain_dim_size, 1, index_dtype, kStrided, index_device)
.repeat(nbatches * compressed_dim_size)
.index_select(0, flat_ordering)
.unflatten(0, values_size);

Tensor values = input.flatten()
.index_select(0, flat_ordering)
.unflatten(0, values_size);

return at::native::_sparse_compressed_tensor_unsafe(
compressed_indices,
plain_indices,
values,
self.sizes(),
self.scalar_type(),
target_layout,
self.device());
}

// Sparse layout conversions Start

Tensor dense_to_sparse_csr(const Tensor& self) {
auto n_batch_dim = self.dim() - 2;
auto values = self;
auto not_zero_mask = self != 0;

if (n_batch_dim > 0) {
dense_to_sparse_compressed_prepare_check_mask_values_batched(
Layout::SparseCsr, values, not_zero_mask, n_batch_dim);
return dense_to_batched_sparse_compressed_nonblock(self, Layout::SparseCsr);
}
auto values = self;
auto not_zero_mask = self != 0;

Tensor col_indices;
Tensor row_indices;
@@ -815,10 +892,6 @@ Tensor dense_to_sparse_csr(const Tensor& self) {
values = values.flatten().index_select(0, mask_indices);
}

if (n_batch_dim > 0) {
reshape_2d_sparse_compressed_members_to_nd_batched(
self.sizes(), n_batch_dim, crow_indices, col_indices, values);
}
return at::native::_sparse_csr_tensor_unsafe(
crow_indices,
col_indices,
@@ -831,13 +904,11 @@ Tensor dense_to_sparse_csr(const Tensor& self) {

Tensor dense_to_sparse_csc(const Tensor& self) {
auto n_batch_dim = self.dim() - 2;
auto values = self;
auto not_zero_mask = self != 0;

if (n_batch_dim > 0) {
dense_to_sparse_compressed_prepare_check_mask_values_batched(
Layout::SparseCsc, values, not_zero_mask, n_batch_dim);
return dense_to_batched_sparse_compressed_nonblock(self, Layout::SparseCsc);
}
auto values = self;
auto not_zero_mask = self != 0;

Tensor col_indices;
Tensor row_indices;
@@ -855,10 +926,6 @@ Tensor dense_to_sparse_csc(const Tensor& self) {
values = values.index_select(0, mask_indices);
}

if (n_batch_dim > 0) {
reshape_2d_sparse_compressed_members_to_nd_batched(
self.sizes(), n_batch_dim, ccol_indices, row_indices, values);
}
return at::native::_sparse_csc_tensor_unsafe(
ccol_indices,
row_indices,
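
For reference, a rough Python sketch of the padding scheme that dense_to_batched_sparse_compressed_nonblock above implements for CSR: each batch keeps its own compressed indices, while plain indices and values are padded up to the maximum NSE across batches. The helper name and the per-batch loop are illustrative only, not part of the PR.

import torch

def batched_csr_members_sketch(dense):
    # dense: (B, M, N) strided tensor; zeros are treated as unspecified elements.
    B, M, N = dense.shape
    mask = dense != 0
    nse = mask.sum(dim=(-2, -1))                   # per-batch NSE
    max_nse = int(nse.max())
    crow = torch.zeros(B, M + 1, dtype=torch.int64)
    crow[:, 1:] = mask.sum(dim=-1).cumsum(dim=-1)  # per-batch compressed row indices
    col = torch.zeros(B, max_nse, dtype=torch.int64)
    values = torch.zeros(B, max_nse, dtype=dense.dtype)
    for b in range(B):
        n = int(nse[b])
        col[b, :n] = mask[b].nonzero()[:, 1]       # column of each specified element
        values[b, :n] = dense[b][mask[b]]          # row-major order matches nonzero()
    return crow, col, values
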
15 changes: 12 additions & 3 deletions aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -789,10 +789,19 @@ Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {

if (dim < n_batch) {
// Selecting batch dimension
Tensor item_compressed_indices = compressed_indices.select(dim, index);
Tensor item_plain_indices = plain_indices.select(dim, index);
Tensor item_values = self.values().select(dim, index);
int64_t nse = AT_DISPATCH_INTEGRAL_TYPES(item_compressed_indices.scalar_type(), "select",
[&]() -> int64_t { return item_compressed_indices.select(-1, -1).max().item<scalar_t>(); });
if (nse < item_plain_indices.size(-1)) {
item_plain_indices = item_plain_indices.slice(0, 0, nse, 1);
item_values = item_values.slice(0, 0, nse, 1);
}
return at::native::_sparse_compressed_tensor_unsafe(
compressed_indices.select(dim, index),
plain_indices.select(dim, index),
self.values().select(dim, index),
item_compressed_indices,
item_plain_indices,
item_values,
new_sizes,
optTypeMetaToScalarType(options.dtype_opt()),
options.layout_opt(),
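
The select_sparse_csr change above trims that padding when a batch item is selected: the item's actual NSE is read from the last entry of its compressed indices, and the plain indices and values are sliced down to it. Roughly, in terms of the hypothetical members from the earlier sketch:

def select_batch_sketch(crow, col, values, b):
    crow_b = crow[b]
    nse_b = int(crow_b[-1])      # actual NSE of this batch item
    # drop the padded tail so the 2-D result satisfies crow_b[-1] == nse_b
    return crow_b, col[b, :nse_b], values[b, :nse_b]
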
24 changes: 21 additions & 3 deletions aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h
@@ -79,6 +79,20 @@ _check_last_cidx_is_nnz(const index_t& cidx, const index_t& nnz) {
}
}

// Invariant 5.2, relaxed to allow different NSE between batches:
// compressed_indices[..., -1] <= nnz.
template <CDimName cdim_name, typename index_t>
INVARIANT_CHECK_FUNC_API
_check_last_cidx_is_nnz_or_less(const index_t& cidx, const index_t& nnz) {
const bool invariant = cidx <= nnz;
if (cdim_name == CDimName::CRow) {
_assert(invariant, "`crow_indices[..., -1] <= nnz` is not satisfied.");
}
else {
_assert(invariant, "`ccol_indices[..., -1] <= nnz` is not satisfied.");
}
}

// Invariant 5.3
// 0 <= compressed_indices[..., 1:] - compressed_indices[..., :-1] <= plain_dim.
template <CDimName cdim_name, typename index_t>
@@ -253,11 +267,11 @@ void _validate_compressed_sparse_indices_kernel(
.add_input(batch_idx)
.build();

AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [&iter, &idx, dim, nnz] () {
AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [&iter, &idx, dim, nnz, batch_count] () {
const auto* RESTRICT ptr_idx = idx.data_ptr<index_t>();
const auto zero = index_t {0};
KernelLauncher::launch(iter,
[zero, dim, nnz, ptr_idx] FUNCAPI (
[zero, dim, nnz, batch_count, ptr_idx] FUNCAPI (
index_t cidx_first,
index_t cidx_last,
index_t cidx_curr,
Expand All @@ -266,7 +280,11 @@ void _validate_compressed_sparse_indices_kernel(
// Invariant 5.1
_check_first_cidx_is_zero<cdim_name, index_t>(cidx_first, zero);
// Invariant 5.2
_check_last_cidx_is_nnz<cdim_name, index_t>(cidx_last, nnz);
if (batch_count>1) {
_check_last_cidx_is_nnz_or_less<cdim_name, index_t>(cidx_last, nnz);
} else {
_check_last_cidx_is_nnz<cdim_name, index_t>(cidx_last, nnz);
}
// Invariant 5.3
_check_cidx_nondecreasing_locally_bounded_sequence<cdim_name, index_t>(cidx_curr, cidx_next, zero, dim);
// Invariant 5.6
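
In effect, invariant 5.2 is only relaxed to compressed_indices[..., -1] <= nnz when there is more than one batch, because nnz then refers to the padded (maximum) NSE. A Python-level paraphrase of the dispatch between the two checks (illustrative only):

def check_last_cidx_sketch(cidx_last, nnz, batch_count):
    if batch_count > 1:
        # batches may specify fewer elements than the padded nnz
        assert cidx_last <= nnz, "`crow_indices[..., -1] <= nnz` is not satisfied."
    else:
        # unbatched (or single-batch) case: the last compressed index must equal nnz
        assert cidx_last == nnz
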
9 changes: 0 additions & 9 deletions test/test_mm.py
@@ -206,15 +206,6 @@ def sample_inputs_generator():
other = sample_input.args[0]
nse = (other != 0).sum((-2, -1))
for other_layout in layouts:
if nse.max() != nse.min() and other_layout in {torch.sparse_csr, torch.sparse_csc,
torch.sparse_bsr, torch.sparse_bsc}:
# TODO: sparse compressed tensors
# require equal NSE in batches (see PR
# 84834 that removes this
# restriction). After PR 84834 (or
# equivalent) lands, remove this
# if-block
continue
if other.layout != other_layout:
yield SampleInput(sample_input.input.clone(),
args=(torch.sparse._to_layout(other, other_layout),),
12 changes: 8 additions & 4 deletions test/test_sparse_csr.py
@@ -2782,7 +2782,7 @@ def _generate_subject(sparse_shape, batch_shape, hybrid_shape):
dense_back = sparse.to_dense()
self.assertEqual(dense, dense_back)

# if batches have different nnz we expect the conversion to throw
# if batches have different nnz we expect the conversion to throw for blocked layouts
mask_0 = mask[0]
mask_1 = mask[0].clone().fill_(True)
mask_2 = mask[0].clone().fill_(False)
@@ -2791,9 +2791,13 @@ def _generate_subject(sparse_shape, batch_shape, hybrid_shape):
mask = torch.stack([(mask_0, mask_1, mask_2)[i % 3] for i in range(n_batch)], dim=0).reshape(batch_shape + mask_0.shape)
dense = make_tensor(shape, dtype=torch.float, device=device)
dense = dense * mask
msg = "Expect the same number of specified elements per batch."
with self.assertRaisesRegex(RuntimeError, msg):
self._convert_to_layout(dense, layout, blocksize)
if layout in blocked_layouts:
msg = "Expect the same number of specified elements per batch."
with self.assertRaisesRegex(RuntimeError, msg):
self._convert_to_layout(dense, layout, blocksize)
else:
sparse = self._convert_to_layout(dense, layout, blocksize)
check_content(sparse, dense, blocksize=blocksize, batch_shape=batch_shape, hybrid_shape=hybrid_shape)

# Should throw if there is a zero in the batch size
dense = make_tensor((0,) + shape, dtype=torch.float, device=device)
27 changes: 11 additions & 16 deletions torch/testing/_internal/common_methods_invocations.py
@@ -1046,22 +1046,17 @@ def template_reference_inputs_bmm_sparse(self, input_layout, device, dtype, requ
args=(b_sparse.detach().requires_grad_(requires_grad),),
kwargs=sample_input.kwargs)

if input_layout not in {torch.sparse_csr, torch.sparse_csc,
torch.sparse_bsr, torch.sparse_bsc}:
# TODO: sparse compressed tensors require equal NSE in batches
# (see PR 84834 that removes this restriction). After PR 84834
# (or equivalent) lands, enable this if-block for all layouts.
a = make_tensor((2 * S, S + 2, S + 3), dtype=dtype, device=device, requires_grad=requires_grad, exclude_zero=True)
a[0].fill_(0)
a[2].fill_(0)
a[3, 1].fill_(0)
a[4, :, 1].fill_(0)
b = make_tensor((2 * S, S + 3, S + 1), dtype=dtype, device=device, requires_grad=requires_grad, exclude_zero=True)
b[1].fill_(0)
b[2].fill_(0)
b[3, 2].fill_(0)
b[4, :, 2].fill_(0)
yield SampleInput(torch.sparse._to_layout(a, input_layout), args=(b,))
a = make_tensor((2 * S, S + 2, S + 3), dtype=dtype, device=device, requires_grad=requires_grad, exclude_zero=True)
a[0].fill_(0)
a[2].fill_(0)
a[3, 1].fill_(0)
a[4, :, 1].fill_(0)
b = make_tensor((2 * S, S + 3, S + 1), dtype=dtype, device=device, requires_grad=requires_grad, exclude_zero=True)
b[1].fill_(0)
b[2].fill_(0)
b[3, 2].fill_(0)
b[4, :, 2].fill_(0)
yield SampleInput(torch.sparse._to_layout(a, input_layout), args=(b,))


def reference_inputs_sparse_coo_bmm(self, device, dtype, requires_grad, **kwargs):
Expand Down