Support different NSE in batches of CSR and CSC tensors
[ghstack-poisoned]
pearu committed Sep 11, 2022
1 parent c794ee5 commit cf6565a
Showing 4 changed files with 129 additions and 28 deletions.
106 changes: 88 additions & 18 deletions aten/src/ATen/native/TensorConversions.cpp
@@ -792,17 +792,97 @@ std::pair<Tensor, Tensor> _not_zero_mask_to_col_row_indices(
return std::pair<Tensor, Tensor>(col_indices, row_indices);
}

Tensor dense_to_batched_sparse_compressed_nonblock(const Tensor& self, const Layout& target_layout) {
ScalarType index_dtype = at::kLong;
Device index_device = self.device();
auto n_batch_dim = self.dim() - 2;
TORCH_INTERNAL_ASSERT(n_batch_dim > 0);
auto nvalues = self.numel();
int64_t compressed_dim_size, plain_dim_size;
std::tie(compressed_dim_size, plain_dim_size) = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(target_layout, "dense_to_batched_sparse_compressed_nonblock",
[&] { return std::make_tuple(self.size(-2), self.size(-1)); },
[&] { return std::make_tuple(self.size(-1), self.size(-2)); });
auto batchsize = self.sizes().slice(0, n_batch_dim);
auto nbatches = size_from_dim_(0, batchsize);
TORCH_CHECK(nbatches > 0,
"to_sparse_",
sparse_csr::layoutToString(target_layout, false, true),
": Expected product of batch dimensions to be non-zero.");
Tensor input = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(target_layout, "dense_to_batched_sparse_compressed_nonblock",
[&] { return self; },
[&] { return self.transpose(-2, -1); });
Tensor non_zero_mask = (input != 0).flatten(0, n_batch_dim-1).flatten(-2, -1);
Tensor nse = non_zero_mask.sum(1);
int64_t max_nse = AT_DISPATCH_INTEGRAL_TYPES(nse.scalar_type(), "dense_to_batched_sparse_compressed_nonblock",
[&]() -> int64_t { return nse.max().item<scalar_t>(); });
Tensor flat_uncompressed_indices = at::native::arange(nbatches * compressed_dim_size, index_dtype, kStrided, index_device)
.repeat_interleave(plain_dim_size)
.flatten()
.masked_select(non_zero_mask.flatten());
Tensor flat_compressed_indices = at::_convert_indices_from_coo_to_csr(flat_uncompressed_indices, nbatches * compressed_dim_size, false /*out_int32*/);
Tensor batch_compressed_indices = at::zeros({nbatches, compressed_dim_size + 1}, flat_compressed_indices.options());
if (compressed_dim_size > 0) {
batch_compressed_indices.narrow(1, 1, compressed_dim_size)
.copy_(flat_compressed_indices.slice(0, 1, c10::nullopt, 1).reshape({nbatches, compressed_dim_size})
- flat_compressed_indices.slice(0, 0, -1, compressed_dim_size).reshape({nbatches, 1}));
}
auto compressed_indices_size = DimVector(batchsize);
compressed_indices_size.push_back(compressed_dim_size + 1);
Tensor compressed_indices = batch_compressed_indices.reshape(compressed_indices_size);

Tensor batch_flat_indices = non_zero_mask.to(kByte).argsort(true /*stable*/, -1 /*dim*/, true /*descending*/)
.slice(1, 0, max_nse, 1);
/* Computing batch_flat_indices using argsort as above is equivalent to
batch_flat_indices = torch.zeros((nbatches, max_nse), dtype=torch.int64)
for i in range(nbatches):
tmp = torch.arange(ncols * nrows).masked_select(non_zero_mask[i])
batch_flat_indices[i, :tmp.numel()] = tmp
up to the trailing padding entries (positions beyond the True values of
non_zero_mask), which we do not care about. topk cannot be used here
because its sort is not stable.
*/

if (nvalues > 0) {
batch_flat_indices.add_(at::native::arange(0, nvalues, compressed_dim_size * plain_dim_size, index_dtype, kStrided, index_device)
.reshape({nbatches, 1}));
}

Tensor flat_ordering = batch_flat_indices.flatten();

// plain_indices and values have the same size because dense
// dimensionality is 0.
auto values_size = DimVector(batchsize);
values_size.push_back(max_nse);

Tensor plain_indices = at::native::arange(0, plain_dim_size, 1, index_dtype, kStrided, index_device)
.repeat(nbatches * compressed_dim_size)
.index_select(0, flat_ordering)
.unflatten(0, values_size);

Tensor values = input.flatten()
.index_select(0, flat_ordering)
.unflatten(0, values_size);

return at::native::_sparse_compressed_tensor_unsafe(
compressed_indices,
plain_indices,
values,
self.sizes(),
self.scalar_type(),
target_layout,
self.device());
}
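
For reference, here is a minimal Python sketch (not part of the commit) of the layout that dense_to_batched_sparse_compressed_nonblock produces for the CSR case: each batch records its own NSE in crow_indices[..., -1], while col_indices and values are padded out to the largest NSE over all batches. The helper name and the pure-Python loop are illustrative only; the C++ code above replaces the loop with a stable descending argsort.

import torch

def dense_to_batched_csr_reference(dense):
    # dense: (nbatches, nrows, ncols), no hybrid (dense) dimensions.
    nbatches, nrows, ncols = dense.shape
    mask = dense != 0
    nse = mask.sum(dim=(-2, -1))                 # per-batch number of specified elements
    max_nse = int(nse.max())
    # crow_indices[b, -1] == nse[b]; batches with fewer specified elements
    # simply leave the trailing slots of col_indices/values unused.
    crow_indices = torch.zeros((nbatches, nrows + 1), dtype=torch.int64)
    crow_indices[:, 1:] = mask.sum(-1).cumsum(-1)
    col_indices = torch.zeros((nbatches, max_nse), dtype=torch.int64)
    values = torch.zeros((nbatches, max_nse), dtype=dense.dtype)
    for b in range(nbatches):
        flat = torch.arange(nrows * ncols).masked_select(mask[b].flatten())
        col_indices[b, :flat.numel()] = flat % ncols
        values[b, :flat.numel()] = dense[b].flatten()[flat]
    return torch.sparse_csr_tensor(crow_indices, col_indices, values, size=dense.shape)

For the CSC target layout the same construction is applied to the transpose of the input, which is what the AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS branches above select.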

// Sparse layout conversions Start

Tensor dense_to_sparse_csr(const Tensor& self) {
auto n_batch_dim = self.dim() - 2;
auto values = self;
auto not_zero_mask = self != 0;

if (n_batch_dim > 0) {
dense_to_sparse_compressed_prepare_check_mask_values_batched(
Layout::SparseCsr, values, not_zero_mask, n_batch_dim);
return dense_to_batched_sparse_compressed_nonblock(self, Layout::SparseCsr);
}
auto values = self;
auto not_zero_mask = self != 0;

Tensor col_indices;
Tensor row_indices;
@@ -815,10 +895,6 @@ Tensor dense_to_sparse_csr(const Tensor& self) {
values = values.flatten().index_select(0, mask_indices);
}

if (n_batch_dim > 0) {
reshape_2d_sparse_compressed_members_to_nd_batched(
self.sizes(), n_batch_dim, crow_indices, col_indices, values);
}
return at::native::_sparse_csr_tensor_unsafe(
crow_indices,
col_indices,
@@ -831,13 +907,11 @@

Tensor dense_to_sparse_csc(const Tensor& self) {
auto n_batch_dim = self.dim() - 2;
auto values = self;
auto not_zero_mask = self != 0;

if (n_batch_dim > 0) {
dense_to_sparse_compressed_prepare_check_mask_values_batched(
Layout::SparseCsc, values, not_zero_mask, n_batch_dim);
return dense_to_batched_sparse_compressed_nonblock(self, Layout::SparseCsc);
}
auto values = self;
auto not_zero_mask = self != 0;

Tensor col_indices;
Tensor row_indices;
@@ -855,10 +929,6 @@
values = values.index_select(0, mask_indices);
}

if (n_batch_dim > 0) {
reshape_2d_sparse_compressed_members_to_nd_batched(
self.sizes(), n_batch_dim, ccol_indices, row_indices, values);
}
return at::native::_sparse_csc_tensor_unsafe(
ccol_indices,
row_indices,
15 changes: 12 additions & 3 deletions aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -789,10 +789,19 @@ Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {

if (dim < n_batch) {
// Selecting batch dimension
Tensor item_compressed_indices = compressed_indices.select(dim, index);
Tensor item_plain_indices = plain_indices.select(dim, index);
Tensor item_values = self.values().select(dim, index);
int64_t nse = AT_DISPATCH_INTEGRAL_TYPES(item_compressed_indices.scalar_type(), "select",
[&]() -> int64_t { return item_compressed_indices.select(-1, -1).max().item<scalar_t>(); });
if (nse < item_plain_indices.size(-1)) {
item_plain_indices = item_plain_indices.slice(0, 0, nse, 1);
item_values = item_values.slice(0, 0, nse, 1);
}
return at::native::_sparse_compressed_tensor_unsafe(
compressed_indices.select(dim, index),
plain_indices.select(dim, index),
self.values().select(dim, index),
item_compressed_indices,
item_plain_indices,
item_values,
new_sizes,
optTypeMetaToScalarType(options.dtype_opt()),
options.layout_opt(),
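
A small illustration (not from the commit; the numbers are made up) of the select change above: selecting a batch of a CSR tensor whose batches have different NSE now slices plain_indices and values down to that batch's own NSE, read from compressed_indices[..., -1], so the padding added by the batched conversion does not leak into the 2D result.

import torch

dense = torch.tensor([[[1., 2.],
                       [0., 3.]],   # batch 0: 3 specified elements
                      [[0., 0.],
                       [4., 0.]]])  # batch 1: 1 specified element
batched = dense.to_sparse_csr()     # internally padded to max NSE = 3
item = batched.select(0, 1)         # 2D CSR tensor for batch 1
# item.crow_indices() is [0, 0, 1]; item.col_indices() and item.values()
# are sliced to length 1, dropping the two padding slots.
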
24 changes: 21 additions & 3 deletions aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h
@@ -89,6 +89,20 @@ _check_last_cidx_is_nnz(const index_t& cidx, const index_t& nnz) {
}
}

// Invariant 5.2, modified to allow a different NSE in each batch:
// compressed_indices[..., -1] <= nnz.
template <CDimName cdim_name, typename index_t>
INVARIANT_CHECK_FUNC_API
_check_last_cidx_is_nnz_or_less(const index_t& cidx, const index_t& nnz) {
const bool invariant = cidx <= nnz;
if (cdim_name == CDimName::CRow) {
_assert(invariant, "`crow_indices[..., -1] <= nnz` is not satisfied.");
}
else {
_assert(invariant, "`ccol_indices[..., -1] <= nnz` is not satisfied.");
}
}

// Invariant 5.3
// 0 <= compressed_indices[..., 1:] - compressed_indices[..., :-1] <= plain_dim.
template <CDimName cdim_name, typename index_t>
@@ -263,11 +277,11 @@ void _validate_compressed_sparse_indices_kernel(
.add_input(batch_idx)
.build();

AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [&iter, &idx, dim, nnz] () {
AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [&iter, &idx, dim, nnz, batch_count] () {
const auto* RESTRICT ptr_idx = idx.data_ptr<index_t>();
const auto zero = index_t {0};
KernelLauncher::launch(iter,
[zero, dim, nnz, ptr_idx] FUNCAPI (
[zero, dim, nnz, batch_count, ptr_idx] FUNCAPI (
index_t cidx_first,
index_t cidx_last,
index_t cidx_curr,
@@ -276,7 +290,11 @@
// Invariant 5.1
_check_first_cidx_is_zero<cdim_name, index_t>(cidx_first, zero);
// Invariant 5.2
_check_last_cidx_is_nnz<cdim_name, index_t>(cidx_last, nnz);
if (batch_count > 1) {
_check_last_cidx_is_nnz_or_less<cdim_name, index_t>(cidx_last, nnz);
} else {
_check_last_cidx_is_nnz<cdim_name, index_t>(cidx_last, nnz);
}
// Invariant 5.3
_check_cidx_nondecreasing_locally_bounded_sequence<cdim_name, index_t>(cidx_curr, cidx_next, zero, dim);
// Invariant 5.6
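
In words, the relaxed invariant 5.2 requires each batch's last compressed index to be bounded by nnz (the shared trailing size of plain_indices/values) instead of equal to it, and it only applies when there is more than one batch. A hedged sketch of batched CSR indices that are now considered valid (values chosen for illustration):

import torch

# Two 2x2 CSR batches sharing nnz == values.shape[-1] == 3.
crow_indices = torch.tensor([[0, 2, 3],    # batch 0 uses all 3 slots
                             [0, 0, 1]])   # batch 1 uses 1 of 3 slots; 1 <= 3
col_indices = torch.tensor([[0, 1, 1],
                            [0, 0, 0]])    # trailing entries are padding
values = torch.tensor([[1., 2., 3.],
                       [4., 0., 0.]])
# Under the old invariant crow_indices[..., -1] == nnz, batch 1 would be
# rejected; with batch_count > 1 the relaxed <= nnz check accepts it.
t = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2, 2))
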
12 changes: 8 additions & 4 deletions test/test_sparse_csr.py
@@ -2780,7 +2780,7 @@ def _generate_subject(sparse_shape, batch_shape, hybrid_shape):
dense_back = sparse.to_dense()
self.assertEqual(dense, dense_back)

# if batches have different nnz we expect the conversion to throw
# if batches have different nnz we expect the conversion to throw for blocked layouts
mask_0 = mask[0]
mask_1 = mask[0].clone().fill_(True)
mask_2 = mask[0].clone().fill_(False)
@@ -2789,9 +2789,13 @@ def _generate_subject(sparse_shape, batch_shape, hybrid_shape):
mask = torch.stack([(mask_0, mask_1, mask_2)[i % 3] for i in range(n_batch)], dim=0).reshape(batch_shape + mask_0.shape)
dense = make_tensor(shape, dtype=torch.float, device=device)
dense = dense * mask
msg = "Expect the same number of specified elements per batch."
with self.assertRaisesRegex(RuntimeError, msg):
self._convert_to_layout(dense, layout, blocksize)
if layout in blocked_layouts:
msg = "Expect the same number of specified elements per batch."
with self.assertRaisesRegex(RuntimeError, msg):
self._convert_to_layout(dense, layout, blocksize)
else:
sparse = self._convert_to_layout(dense, layout, blocksize)
check_content(sparse, dense, blocksize=blocksize, batch_shape=batch_shape, hybrid_shape=hybrid_shape)

# Should throw if there is a zero in the batch size
dense = make_tensor((0,) + shape, dtype=torch.float, device=device)
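
A condensed version of what the new non-blocked branch of the test exercises (a sketch, not the test code itself): batches with different numbers of specified elements now convert to CSR/CSC and round-trip back to the original dense tensor.

import torch

dense = torch.zeros(3, 2, 4)
dense[0, 0, 1] = 1.0    # batch 0: one specified element
dense[1] = 2.0          # batch 1: fully dense
                        # batch 2: no specified elements
for convert in (torch.Tensor.to_sparse_csr, torch.Tensor.to_sparse_csc):
    sparse = convert(dense)
    assert torch.equal(sparse.to_dense(), dense)
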
