Support different NSE in batches of CSR and CSC tensors
ghstack-source-id: 19d3b9616b62846c14ef45a85ea4468db42e0836
Pull Request resolved: #84843
pearu committed Sep 19, 2022
1 parent 082a6d4 commit 99f96a4
Showing 6 changed files with 137 additions and 53 deletions.
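
For illustration, a minimal sketch of the conversion this change enables, assuming the public Tensor.to_sparse_csr API: a batched dense tensor whose batches contain different numbers of specified elements (NSE) now converts to a batched CSR tensor, with col_indices and values padded to the maximum NSE across batches (per the new routine below).

import torch

# Two 3x3 matrices with different NSE: batch 0 has 1 specified element, batch 1 has 3.
dense = torch.zeros(2, 3, 3)
dense[0, 0, 0] = 1.0
dense[1, 1] = torch.tensor([2.0, 3.0, 4.0])

csr = dense.to_sparse_csr()   # previously required equal NSE in every batch
print(csr.crow_indices())     # shape (2, 4): per-batch compressed row indices
print(csr.col_indices())      # shape (2, 3): zero-padded to the maximum NSE
print(csr.values())           # shape (2, 3): padded like col_indices
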
103 changes: 85 additions & 18 deletions aten/src/ATen/native/TensorConversions.cpp
@@ -792,17 +792,94 @@ std::pair<Tensor, Tensor> _not_zero_mask_to_col_row_indices(
return std::pair<Tensor, Tensor>(col_indices, row_indices);
}

Tensor dense_to_batched_sparse_compressed_nonblock(const Tensor& self, const Layout& target_layout) {
ScalarType index_dtype = at::kLong;
Device index_device = self.device();
auto n_batch_dim = self.dim() - 2;
TORCH_INTERNAL_ASSERT(n_batch_dim > 0);
int compressed_dim_size, plain_dim_size;
std::tie(compressed_dim_size, plain_dim_size) = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(target_layout, "dense_to_batched_sparse_compressed_nonblock",
[&] { return std::make_tuple(self.size(-2), self.size(-1)); },
[&] { return std::make_tuple(self.size(-1), self.size(-2)); });
auto batchsize = self.sizes().slice(0, n_batch_dim);
auto nbatches = size_from_dim_(0, batchsize);
TORCH_CHECK(nbatches > 0,
"to_sparse_",
sparse_csr::layoutToString(target_layout, false, true),
": Expected product of batch dimensions to be non-zero.");
Tensor input = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(target_layout, "dense_to_batched_sparse_compressed_nonblock",
[&] { return self; },
[&] { return self.transpose(-2, -1); });
Tensor non_zero_mask = (input != 0).flatten(0, n_batch_dim-1).flatten(-2, -1);
Tensor nse = non_zero_mask.sum(1);
int64_t max_nse = AT_DISPATCH_INTEGRAL_TYPES(nse.scalar_type(), "dense_to_batched_sparse_compressed_nonblock",
[&]() -> int64_t { return nse.max().item<scalar_t>(); });
Tensor flat_uncompressed_indices = at::native::arange(nbatches * compressed_dim_size, index_dtype, kStrided, index_device)
.repeat_interleave(plain_dim_size)
.flatten()
.masked_select(non_zero_mask.flatten());
Tensor flat_compressed_indices = at::_convert_indices_from_coo_to_csr(flat_uncompressed_indices, nbatches * compressed_dim_size, false /*out_int32*/);
Tensor batch_compressed_indices = at::zeros({nbatches, compressed_dim_size + 1}, flat_compressed_indices.options());
if (compressed_dim_size > 0) {
batch_compressed_indices.narrow(1, 1, compressed_dim_size)
.copy_(flat_compressed_indices.slice(0, 1, c10::nullopt, 1).reshape({nbatches, compressed_dim_size})
- flat_compressed_indices.slice(0, 0, -1, compressed_dim_size).reshape({nbatches, 1}));
}
auto compressed_indices_size = DimVector(batchsize);
compressed_indices_size.push_back(compressed_dim_size + 1);
Tensor compressed_indices = batch_compressed_indices.reshape(compressed_indices_size);

Tensor batch_flat_indices = at::zeros({nbatches, max_nse}, flat_compressed_indices.options());
Tensor non_zero_indices = non_zero_mask.flatten().nonzero().flatten();
Tensor nse_cpu = nse.cpu();
AT_DISPATCH_INTEGRAL_TYPES(nse_cpu.scalar_type(), "dense_to_batched_sparse_compressed_nonblock",
[&]() {
scalar_t cnse_im1 = 0, cnse_i = 0;
scalar_t* nse_ptr = nse_cpu.data_ptr<scalar_t>();
for (const auto i : c10::irange(0, nbatches)) {
const auto nse_i = nse_ptr[i];
cnse_i += nse_i;
batch_flat_indices.select(0, i)
.narrow(0, 0, nse_i)
.copy_(non_zero_indices.slice(0, cnse_im1, cnse_i, 1));
cnse_im1 = cnse_i;
}
});
Tensor flat_ordering = batch_flat_indices.flatten();

// plain_indices and values have the same size because the dense
// dimensionality is 0.
auto values_size = DimVector(batchsize);
values_size.push_back(max_nse);

Tensor plain_indices = at::native::arange(0, plain_dim_size, 1, index_dtype, kStrided, index_device)
.repeat(nbatches * compressed_dim_size)
.index_select(0, flat_ordering)
.unflatten(0, values_size);

Tensor values = input.flatten()
.index_select(0, flat_ordering)
.unflatten(0, values_size);

return at::native::_sparse_compressed_tensor_unsafe(
compressed_indices,
plain_indices,
values,
self.sizes(),
self.scalar_type(),
target_layout,
self.device());
}

// Sparse layout conversions Start

Tensor dense_to_sparse_csr(const Tensor& self) {
auto n_batch_dim = self.dim() - 2;
auto values = self;
auto not_zero_mask = self != 0;

if (n_batch_dim > 0) {
dense_to_sparse_compressed_prepare_check_mask_values_batched(
Layout::SparseCsr, values, not_zero_mask, n_batch_dim);
return dense_to_batched_sparse_compressed_nonblock(self, Layout::SparseCsr);
}
auto values = self;
auto not_zero_mask = self != 0;

Tensor col_indices;
Tensor row_indices;
@@ -815,10 +892,6 @@ Tensor dense_to_sparse_csr(const Tensor& self) {
values = values.flatten().index_select(0, mask_indices);
}

if (n_batch_dim > 0) {
reshape_2d_sparse_compressed_members_to_nd_batched(
self.sizes(), n_batch_dim, crow_indices, col_indices, values);
}
return at::native::_sparse_csr_tensor_unsafe(
crow_indices,
col_indices,
@@ -831,13 +904,11 @@ Tensor dense_to_sparse_csr(const Tensor& self) {

Tensor dense_to_sparse_csc(const Tensor& self) {
auto n_batch_dim = self.dim() - 2;
auto values = self;
auto not_zero_mask = self != 0;

if (n_batch_dim > 0) {
dense_to_sparse_compressed_prepare_check_mask_values_batched(
Layout::SparseCsc, values, not_zero_mask, n_batch_dim);
return dense_to_batched_sparse_compressed_nonblock(self, Layout::SparseCsc);
}
auto values = self;
auto not_zero_mask = self != 0;

Tensor col_indices;
Tensor row_indices;
@@ -855,10 +926,6 @@ Tensor dense_to_sparse_csc(const Tensor& self) {
values = values.index_select(0, mask_indices);
}

if (n_batch_dim > 0) {
reshape_2d_sparse_compressed_members_to_nd_batched(
self.sizes(), n_batch_dim, ccol_indices, row_indices, values);
}
return at::native::_sparse_csc_tensor_unsafe(
ccol_indices,
row_indices,
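
The core of the change above is dense_to_batched_sparse_compressed_nonblock, which flattens the batch dimensions, performs one flat COO-to-CSR index conversion, and pads each batch's plain indices and values to the maximum NSE. Below is a rough, unvectorized Python analogue for the CSR case (illustrative only; the function name and the per-batch loop are not from the codebase).

import torch

def batched_csr_components(x):
    # x: (nbatches, nrows, ncols) dense tensor.
    # Returns crow_indices, col_indices, values in the padded-batch layout that
    # the C++ routine passes to _sparse_compressed_tensor_unsafe (no blocks, no
    # dense dimensions).
    nb, nrows, ncols = x.shape
    nse = (x != 0).sum((-2, -1))
    max_nse = int(nse.max())
    crow = torch.zeros(nb, nrows + 1, dtype=torch.int64)
    col = torch.zeros(nb, max_nse, dtype=torch.int64)   # zero-padded past each batch's NSE
    vals = torch.zeros(nb, max_nse, dtype=x.dtype)
    for b in range(nb):
        r, c = (x[b] != 0).nonzero(as_tuple=True)       # row-major order, as CSR expects
        crow[b, 1:] = torch.bincount(r, minlength=nrows).cumsum(0)
        col[b, : r.numel()] = c
        vals[b, : r.numel()] = x[b, r, c]
    return crow, col, vals

For the two-batch example above (NSE 1 and 3), this would yield crow_indices [[0, 1, 1, 1], [0, 0, 3, 3]] with col_indices and values padded to width 3.
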
15 changes: 12 additions & 3 deletions aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -789,10 +789,19 @@ Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {

if (dim < n_batch) {
// Selecting batch dimension
Tensor item_compressed_indices = compressed_indices.select(dim, index);
Tensor item_plain_indices = plain_indices.select(dim, index);
Tensor item_values = self.values().select(dim, index);
int64_t nse = AT_DISPATCH_INTEGRAL_TYPES(item_compressed_indices.scalar_type(), "select",
[&]() -> int64_t { return item_compressed_indices.select(-1, -1).max().item<scalar_t>(); });
if (nse < item_plain_indices.size(-1)) {
item_plain_indices = item_plain_indices.slice(0, 0, nse, 1);
item_values = item_values.slice(0, 0, nse, 1);
}
return at::native::_sparse_compressed_tensor_unsafe(
compressed_indices.select(dim, index),
plain_indices.select(dim, index),
self.values().select(dim, index),
item_compressed_indices,
item_plain_indices,
item_values,
new_sizes,
optTypeMetaToScalarType(options.dtype_opt()),
options.layout_opt(),
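
Because plain indices and values are padded to the maximum NSE, selecting a single batch has to drop the padding again. The change above reads the selected batch's true NSE from the last entry of its compressed indices and slices plain_indices and values down to it. A short usage sketch (using the internal _nnz helper):

import torch

dense = torch.zeros(2, 3, 3)
dense[0, 0, 0] = 1.0
dense[1, 1] = torch.tensor([2.0, 3.0, 4.0])
csr = dense.to_sparse_csr()

item0 = csr.select(0, 0)                               # 2D CSR tensor for batch 0
# The selected batch carries only its own NSE, not the padded maximum of 3.
assert item0._nnz() == int(csr.crow_indices()[0, -1])  # == 1
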
24 changes: 21 additions & 3 deletions aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h
@@ -79,6 +79,20 @@ _check_last_cidx_is_nnz(const index_t& cidx, const index_t& nnz) {
}
}

// Invariant 5.2, modified to allow a different NSE between batches:
// compressed_indices[..., -1] <= nnz.
template <CDimName cdim_name, typename index_t>
INVARIANT_CHECK_FUNC_API
_check_last_cidx_is_nnz_or_less(const index_t& cidx, const index_t& nnz) {
const bool invariant = cidx <= nnz;
if (cdim_name == CDimName::CRow) {
_assert(invariant, "`crow_indices[..., -1] <= nnz` is not satisfied.");
}
else {
_assert(invariant, "`ccol_indices[..., -1] <= nnz` is not satisfied.");
}
}

// Invariant 5.3
// 0 <= compressed_indices[..., 1:] - compressed_indices[..., :-1] <= plain_dim.
template <CDimName cdim_name, typename index_t>
@@ -253,11 +267,11 @@ void _validate_compressed_sparse_indices_kernel(
.add_input(batch_idx)
.build();

AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [&iter, &idx, dim, nnz] () {
AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [&iter, &idx, dim, nnz, batch_count] () {
const auto* RESTRICT ptr_idx = idx.data_ptr<index_t>();
const auto zero = index_t {0};
KernelLauncher::launch(iter,
[zero, dim, nnz, ptr_idx] FUNCAPI (
[zero, dim, nnz, batch_count, ptr_idx] FUNCAPI (
index_t cidx_first,
index_t cidx_last,
index_t cidx_curr,
@@ -266,7 +280,11 @@
// Invariant 5.1
_check_first_cidx_is_zero<cdim_name, index_t>(cidx_first, zero);
// Invariant 5.2
_check_last_cidx_is_nnz<cdim_name, index_t>(cidx_last, nnz);
if (batch_count>1) {
_check_last_cidx_is_nnz_or_less<cdim_name, index_t>(cidx_last, nnz);
} else {
_check_last_cidx_is_nnz<cdim_name, index_t>(cidx_last, nnz);
}
// Invariant 5.3
_check_cidx_nondecreasing_locally_bounded_sequence<cdim_name, index_t>(cidx_curr, cidx_next, zero, dim);
// Invariant 5.6
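
The relaxed invariant matters when batched compressed tensors are built directly from padded components: a batch's last compressed index may then be strictly smaller than the padded NSE (the length of the plain-indices/values dimension). A sketch, assuming the public factory runs these index checks on batched inputs:

import torch

crow = torch.tensor([[0, 1, 1, 1],    # batch 0: 1 specified element
                     [0, 0, 3, 3]])   # batch 1: 3 specified elements
col = torch.tensor([[0, 0, 0],        # zero-padded past batch 0's NSE
                    [0, 1, 2]])
vals = torch.tensor([[1.0, 0.0, 0.0],
                     [2.0, 3.0, 4.0]])

# Batch 0 ends at crow[0, -1] == 1 < nnz == 3: accepted now that invariant 5.2
# is checked as `<=` whenever there is more than one batch.
csr = torch.sparse_csr_tensor(crow, col, vals, size=(2, 3, 3))
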
9 changes: 0 additions & 9 deletions test/test_mm.py
@@ -206,15 +206,6 @@ def sample_inputs_generator():
other = sample_input.args[0]
nse = (other != 0).sum((-2, -1))
for other_layout in layouts:
if nse.max() != nse.min() and other_layout in {torch.sparse_csr, torch.sparse_csc,
torch.sparse_bsr, torch.sparse_bsc}:
# TODO: sparse compressed tensors
# require equal NSE in batches (see PR
# 84834 that removes this
# restriction). After PR 84834 (or
# equivalent) lands, remove this
# if-block
continue
if other.layout != other_layout:
yield SampleInput(sample_input.input.clone(),
args=(torch.sparse._to_layout(other, other_layout),),
12 changes: 8 additions & 4 deletions test/test_sparse_csr.py
@@ -2782,7 +2782,7 @@ def _generate_subject(sparse_shape, batch_shape, hybrid_shape):
dense_back = sparse.to_dense()
self.assertEqual(dense, dense_back)

# if batches have different nnz we expect the conversion to throw
# if batches have different nnz we expect the conversion to throw for blocked layouts
mask_0 = mask[0]
mask_1 = mask[0].clone().fill_(True)
mask_2 = mask[0].clone().fill_(False)
@@ -2791,9 +2791,13 @@
mask = torch.stack([(mask_0, mask_1, mask_2)[i % 3] for i in range(n_batch)], dim=0).reshape(batch_shape + mask_0.shape)
dense = make_tensor(shape, dtype=torch.float, device=device)
dense = dense * mask
msg = "Expect the same number of specified elements per batch."
with self.assertRaisesRegex(RuntimeError, msg):
self._convert_to_layout(dense, layout, blocksize)
if layout in blocked_layouts:
msg = "Expect the same number of specified elements per batch."
with self.assertRaisesRegex(RuntimeError, msg):
self._convert_to_layout(dense, layout, blocksize)
else:
sparse = self._convert_to_layout(dense, layout, blocksize)
check_content(sparse, dense, blocksize=blocksize, batch_shape=batch_shape, hybrid_shape=hybrid_shape)

# Should throw if there is a zero in the batch size
dense = make_tensor((0,) + shape, dtype=torch.float, device=device)
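
A condensed sketch of what the non-blocked branch of the updated test now exercises, using only public APIs (check_content and _convert_to_layout above are test-suite helpers): converting a batched dense tensor with unequal per-batch NSE and checking the result batch by batch.

import torch

dense = torch.zeros(2, 3, 3)
dense[0, 0, 0] = 1.0
dense[1, 1] = torch.tensor([2.0, 3.0, 4.0])

for to_sparse in (torch.Tensor.to_sparse_csr, torch.Tensor.to_sparse_csc):
    sparse = to_sparse(dense)
    for b in range(dense.shape[0]):
        # Each selected batch round-trips to the corresponding dense slice.
        assert torch.equal(sparse.select(0, b).to_dense(), dense[b])
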
27 changes: 11 additions & 16 deletions torch/testing/_internal/common_methods_invocations.py
@@ -1046,22 +1046,17 @@ def template_reference_inputs_bmm_sparse(self, input_layout, device, dtype, requ
args=(b_sparse.detach().requires_grad_(requires_grad),),
kwargs=sample_input.kwargs)

if input_layout not in {torch.sparse_csr, torch.sparse_csc,
torch.sparse_bsr, torch.sparse_bsc}:
# TODO: sparse compressed tensors require equal NSE in batches
# (see PR 84834 that removes this restriction). After PR 84834
# (or equivalent) lands, enable this if-block for all layouts.
a = make_tensor((2 * S, S + 2, S + 3), dtype=dtype, device=device, requires_grad=requires_grad, exclude_zero=True)
a[0].fill_(0)
a[2].fill_(0)
a[3, 1].fill_(0)
a[4, :, 1].fill_(0)
b = make_tensor((2 * S, S + 3, S + 1), dtype=dtype, device=device, requires_grad=requires_grad, exclude_zero=True)
b[1].fill_(0)
b[2].fill_(0)
b[3, 2].fill_(0)
b[4, :, 2].fill_(0)
yield SampleInput(torch.sparse._to_layout(a, input_layout), args=(b,))
a = make_tensor((2 * S, S + 2, S + 3), dtype=dtype, device=device, requires_grad=requires_grad, exclude_zero=True)
a[0].fill_(0)
a[2].fill_(0)
a[3, 1].fill_(0)
a[4, :, 1].fill_(0)
b = make_tensor((2 * S, S + 3, S + 1), dtype=dtype, device=device, requires_grad=requires_grad, exclude_zero=True)
b[1].fill_(0)
b[2].fill_(0)
b[3, 2].fill_(0)
b[4, :, 2].fill_(0)
yield SampleInput(torch.sparse._to_layout(a, input_layout), args=(b,))


def reference_inputs_sparse_coo_bmm(self, device, dtype, requires_grad, **kwargs):
