Support different NSE in batches of CSR and CSC tensors #84843

Draft · wants to merge 5 commits into base: gh/pearu/56/base
Changes from 1 commit
Support different NSE in batches of CSR and CSC tensors
[ghstack-poisoned]
pearu committed Sep 11, 2022
commit cf6565a82a85a89a77b778f96f5c84c8996e6143
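
For orientation, a hypothetical Python-level illustration (not part of this diff) of the behavior the commit targets, assuming the draft's semantics: batched dense-to-CSR conversion where the batches have different numbers of specified elements.

import torch

# Two 2x2 batches with different numbers of nonzeros (2 vs. 1).
dense = torch.tensor([[[1., 0.],
                       [0., 2.]],
                      [[3., 0.],
                       [0., 0.]]])

# Previously this raised "Expect the same number of specified elements per batch."
csr = dense.to_sparse_csr()

# With the change, crow_indices[..., -1] records each batch's own NSE, while
# col_indices and values are padded up to the maximum NSE across batches.
print(csr.crow_indices())
print(csr.col_indices())
print(csr.values())
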
106 changes: 88 additions & 18 deletions aten/src/ATen/native/TensorConversions.cpp
@@ -792,17 +792,97 @@ std::pair<Tensor, Tensor> _not_zero_mask_to_col_row_indices(
return std::pair<Tensor, Tensor>(col_indices, row_indices);
}

Tensor dense_to_batched_sparse_compressed_nonblock(const Tensor& self, const Layout& target_layout) {
ScalarType index_dtype = at::kLong;
Device index_device = self.device();
auto n_batch_dim = self.dim() - 2;
TORCH_INTERNAL_ASSERT(n_batch_dim > 0);
auto nvalues = self.numel();
int compressed_dim_size, plain_dim_size;
std::tie(compressed_dim_size, plain_dim_size) = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(target_layout, "dense_to_batched_sparse_compressed_nonblock",
[&] { return std::make_tuple(self.size(-2), self.size(-1)); },
[&] { return std::make_tuple(self.size(-1), self.size(-2)); });
auto batchsize = self.sizes().slice(0, n_batch_dim);
auto nbatches = size_from_dim_(0, batchsize);
TORCH_CHECK(nbatches > 0,
"to_sparse_",
sparse_csr::layoutToString(target_layout, false, true),
": Expected product of batch dimensions to be non-zero.");
Tensor input = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(target_layout, "dense_to_batched_sparse_compressed_nonblock",
[&] { return self; },
[&] { return self.transpose(-2, -1); });
Tensor non_zero_mask = (input != 0).flatten(0, n_batch_dim-1).flatten(-2, -1);
Tensor nse = non_zero_mask.sum(1);
int64_t max_nse = AT_DISPATCH_INTEGRAL_TYPES(nse.scalar_type(), "dense_to_batched_sparse_compressed_nonblock",
[&]() -> int64_t { return nse.max().item<scalar_t>(); });
Tensor flat_uncompressed_indices = at::native::arange(nbatches * compressed_dim_size, index_dtype, kStrided, index_device)
.repeat_interleave(plain_dim_size)
.flatten()
.masked_select(non_zero_mask.flatten());
Tensor flat_compressed_indices = at::_convert_indices_from_coo_to_csr(flat_uncompressed_indices, nbatches * compressed_dim_size, false /*out_int32*/);
Tensor batch_compressed_indices = at::zeros({nbatches, compressed_dim_size + 1}, flat_compressed_indices.options());
if (compressed_dim_size > 0) {
batch_compressed_indices.narrow(1, 1, compressed_dim_size)
.copy_(flat_compressed_indices.slice(0, 1, c10::nullopt, 1).reshape({nbatches, compressed_dim_size})
- flat_compressed_indices.slice(0, 0, -1, compressed_dim_size).reshape({nbatches, 1}));
}
auto compressed_indices_size = DimVector(batchsize);
compressed_indices_size.push_back(compressed_dim_size + 1);
Tensor compressed_indices = batch_compressed_indices.reshape(compressed_indices_size);

Tensor batch_flat_indices = non_zero_mask.to(kByte).argsort(true /*stable*/, -1 /*dim*/, true /*descending*/)
.slice(1, 0, max_nse, 1);
/* Computing batch_flat_indices with argsort as above is equivalent to

     batch_flat_indices = torch.zeros((nbatches, max_nse), dtype=torch.int64)
     for i in range(nbatches):
         tmp = torch.arange(ncols * nrows).masked_select(non_zero_mask[i])
         batch_flat_indices[i, :tmp.numel()] = tmp

   as far as the entries corresponding to True values of non_zero_mask are
   concerned, which are the only ones we care about; the padding entries
   beyond each batch's NSE may differ. topk cannot be used here because its
   sort is not stable.
*/

if (nvalues > 0) {
batch_flat_indices.add_(at::native::arange(0, nvalues, compressed_dim_size * plain_dim_size, index_dtype, kStrided, index_device)
.reshape({nbatches, 1}));
}

Tensor flat_ordering = batch_flat_indices.flatten();

// plain_indices and values have the same size because dense
// dimensionality is 0.
auto values_size = DimVector(batchsize);
values_size.push_back(max_nse);

Tensor plain_indices = at::native::arange(0, plain_dim_size, 1, index_dtype, kStrided, index_device)
.repeat(nbatches * compressed_dim_size)
.index_select(0, flat_ordering)
.unflatten(0, values_size);

Tensor values = input.flatten()
.index_select(0, flat_ordering)
.unflatten(0, values_size);

return at::native::_sparse_compressed_tensor_unsafe(
compressed_indices,
plain_indices,
values,
self.sizes(),
self.scalar_type(),
target_layout,
self.device());
}

// Sparse layout conversions Start

Tensor dense_to_sparse_csr(const Tensor& self) {
auto n_batch_dim = self.dim() - 2;
- auto values = self;
- auto not_zero_mask = self != 0;
-
if (n_batch_dim > 0) {
- dense_to_sparse_compressed_prepare_check_mask_values_batched(
- Layout::SparseCsr, values, not_zero_mask, n_batch_dim);
return dense_to_batched_sparse_compressed_nonblock(self, Layout::SparseCsr);
}
auto values = self;
auto not_zero_mask = self != 0;

Tensor col_indices;
Tensor row_indices;
@@ -815,10 +895,6 @@ Tensor dense_to_sparse_csr(const Tensor& self) {
values = values.flatten().index_select(0, mask_indices);
}

- if (n_batch_dim > 0) {
- reshape_2d_sparse_compressed_members_to_nd_batched(
- self.sizes(), n_batch_dim, crow_indices, col_indices, values);
- }
return at::native::_sparse_csr_tensor_unsafe(
crow_indices,
col_indices,
@@ -831,13 +907,11 @@

Tensor dense_to_sparse_csc(const Tensor& self) {
auto n_batch_dim = self.dim() - 2;
- auto values = self;
- auto not_zero_mask = self != 0;
-
if (n_batch_dim > 0) {
- dense_to_sparse_compressed_prepare_check_mask_values_batched(
- Layout::SparseCsc, values, not_zero_mask, n_batch_dim);
return dense_to_batched_sparse_compressed_nonblock(self, Layout::SparseCsc);
}
auto values = self;
auto not_zero_mask = self != 0;

Tensor col_indices;
Tensor row_indices;
@@ -855,10 +929,6 @@ Tensor dense_to_sparse_csc(const Tensor& self) {
values = values.index_select(0, mask_indices);
}

- if (n_batch_dim > 0) {
- reshape_2d_sparse_compressed_members_to_nd_batched(
- self.sizes(), n_batch_dim, ccol_indices, row_indices, values);
- }
return at::native::_sparse_csc_tensor_unsafe(
ccol_indices,
row_indices,
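
The new conversion path above is fully vectorized (an argsort over the flattened mask plus _convert_indices_from_coo_to_csr for the compressed indices), which makes it compact but dense to read. Below is a rough, loop-based Python sketch of the same semantics for the CSR case; it is an illustration only, and the function and variable names are mine, not from the diff.

import math
import torch

def batched_dense_to_csr_sketch(dense):
    # dense: (*batch_shape, nrows, ncols), with no dense (trailing) dimensions.
    n_batch_dim = dense.dim() - 2
    assert n_batch_dim > 0
    nrows, ncols = dense.shape[-2], dense.shape[-1]
    batch_shape = dense.shape[:n_batch_dim]
    nbatches = math.prod(batch_shape)

    flat = dense.reshape(nbatches, nrows * ncols)
    mask = flat != 0
    nse = mask.sum(1)          # per-batch number of specified elements
    max_nse = int(nse.max())   # col_indices and values are padded to this length

    row_of = torch.arange(nrows).repeat_interleave(ncols)  # row index of each flat position
    crow = torch.zeros(nbatches, nrows + 1, dtype=torch.int64)
    col = torch.zeros(nbatches, max_nse, dtype=torch.int64)
    values = torch.zeros(nbatches, max_nse, dtype=dense.dtype)

    for b in range(nbatches):
        flat_idx = torch.arange(nrows * ncols)[mask[b]]  # flat positions of the nonzeros
        # Cumulative per-row counts give the compressed (crow) indices; the last
        # entry equals this batch's NSE, which may be smaller than max_nse.
        crow[b, 1:] = torch.bincount(row_of[mask[b]], minlength=nrows).cumsum(0)
        col[b, : flat_idx.numel()] = flat_idx % ncols
        values[b, : flat_idx.numel()] = flat[b][mask[b]]

    return (crow.reshape(*batch_shape, nrows + 1),
            col.reshape(*batch_shape, max_nse),
            values.reshape(*batch_shape, max_nse))

The entries of col and values past each batch's own NSE are padding; the kernel above fills them from the argsort ordering rather than with zeros, but they should never be addressed because crow_indices bounds the reads.
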
15 changes: 12 additions & 3 deletions aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -789,10 +789,19 @@ Tensor select_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {

if (dim < n_batch) {
// Selecting batch dimension
Tensor item_compressed_indices = compressed_indices.select(dim, index);
Tensor item_plain_indices = plain_indices.select(dim, index);
Tensor item_values = self.values().select(dim, index);
int64_t nse = AT_DISPATCH_INTEGRAL_TYPES(item_compressed_indices.scalar_type(), "select",
[&]() -> int64_t { return item_compressed_indices.select(-1, -1).max().item<scalar_t>(); });
if (nse < item_plain_indices.size(-1)) {
item_plain_indices = item_plain_indices.slice(0, 0, nse, 1);
item_values = item_values.slice(0, 0, nse, 1);
}
return at::native::_sparse_compressed_tensor_unsafe(
- compressed_indices.select(dim, index),
- plain_indices.select(dim, index),
- self.values().select(dim, index),
item_compressed_indices,
item_plain_indices,
item_values,
new_sizes,
optTypeMetaToScalarType(options.dtype_opt()),
options.layout_opt(),
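
To make the intent of the change concrete, a hypothetical sketch (names are mine, single batch dimension only) of how selecting a batch now trims the max-NSE padding using the last entry of that batch's compressed indices:

import torch

def select_batch_sketch(crow_indices, col_indices, values, index):
    crow_b = crow_indices[index]
    col_b = col_indices[index]
    values_b = values[index]
    nse_b = int(crow_b[-1])        # this batch's actual NSE
    if nse_b < col_b.shape[-1]:    # drop the padding added to reach the batch-wide max NSE
        col_b = col_b[:nse_b]
        values_b = values_b[:nse_b]
    return crow_b, col_b, values_b

The C++ code takes the maximum of compressed_indices[..., -1] rather than a single scalar because more than one batch dimension may remain after the select.
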
24 changes: 21 additions & 3 deletions aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h
@@ -89,6 +89,20 @@ _check_last_cidx_is_nnz(const index_t& cidx, const index_t& nnz) {
}
}

// Invariant 5.2, modified to allow a different NSE in each batch:
// compressed_indices[..., -1] <= nnz.
template <CDimName cdim_name, typename index_t>
INVARIANT_CHECK_FUNC_API
_check_last_cidx_is_nnz_or_less(const index_t& cidx, const index_t& nnz) {
const bool invariant = cidx <= nnz;
if (cdim_name == CDimName::CRow) {
_assert(invariant, "`crow_indices[..., -1] <= nnz` is not satisfied.");
}
else {
_assert(invariant, "`ccol_indices[..., -1] <= nnz` is not satisfied.");
}
}

// Invariant 5.3
// 0 <= compressed_indices[..., 1:] - compressed_indices[..., :-1] <= plain_dim.
template <CDimName cdim_name, typename index_t>
@@ -263,11 +277,11 @@ void _validate_compressed_sparse_indices_kernel(
.add_input(batch_idx)
.build();

- AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [&iter, &idx, dim, nnz] () {
AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [&iter, &idx, dim, nnz, batch_count] () {
const auto* RESTRICT ptr_idx = idx.data_ptr<index_t>();
const auto zero = index_t {0};
KernelLauncher::launch(iter,
- [zero, dim, nnz, ptr_idx] FUNCAPI (
[zero, dim, nnz, batch_count, ptr_idx] FUNCAPI (
index_t cidx_first,
index_t cidx_last,
index_t cidx_curr,
@@ -276,7 +290,11 @@
// Invariant 5.1
_check_first_cidx_is_zero<cdim_name, index_t>(cidx_first, zero);
// Invariant 5.2
- _check_last_cidx_is_nnz<cdim_name, index_t>(cidx_last, nnz);
if (batch_count > 1) {
_check_last_cidx_is_nnz_or_less<cdim_name, index_t>(cidx_last, nnz);
} else {
_check_last_cidx_is_nnz<cdim_name, index_t>(cidx_last, nnz);
}
// Invariant 5.3
_check_cidx_nondecreasing_locally_bounded_sequence<cdim_name, index_t>(cidx_curr, cidx_next, zero, dim);
// Invariant 5.6
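
Restating the relaxed check in standalone Python (an illustration, not the kernel code): when there is more than one batch, nnz is the padded maximum NSE shared by all batches, so each batch's last compressed index only needs to be bounded by it.

def check_last_cidx(cidx_last, nnz, batch_count):
    if batch_count > 1:
        # Relaxed invariant 5.2: a batch may have fewer specified elements than nnz.
        assert cidx_last <= nnz, "`crow_indices[..., -1] <= nnz` is not satisfied."
    else:
        # Single batch: the original invariant 5.2 still applies.
        assert cidx_last == nnz, "`crow_indices[..., -1] == nnz` is not satisfied."
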
12 changes: 8 additions & 4 deletions test/test_sparse_csr.py
@@ -2780,7 +2780,7 @@ def _generate_subject(sparse_shape, batch_shape, hybrid_shape):
dense_back = sparse.to_dense()
self.assertEqual(dense, dense_back)

- # if batches have different nnz we expect the conversion to throw
# if batches have different nnz we expect the conversion to throw for blocked layouts
mask_0 = mask[0]
mask_1 = mask[0].clone().fill_(True)
mask_2 = mask[0].clone().fill_(False)
@@ -2789,9 +2789,13 @@ def _generate_subject(sparse_shape, batch_shape, hybrid_shape):
mask = torch.stack([(mask_0, mask_1, mask_2)[i % 3] for i in range(n_batch)], dim=0).reshape(batch_shape + mask_0.shape)
dense = make_tensor(shape, dtype=torch.float, device=device)
dense = dense * mask
- msg = "Expect the same number of specified elements per batch."
- with self.assertRaisesRegex(RuntimeError, msg):
- self._convert_to_layout(dense, layout, blocksize)
if layout in blocked_layouts:
msg = "Expect the same number of specified elements per batch."
with self.assertRaisesRegex(RuntimeError, msg):
self._convert_to_layout(dense, layout, blocksize)
else:
sparse = self._convert_to_layout(dense, layout, blocksize)
check_content(sparse, dense, blocksize=blocksize, batch_shape=batch_shape, hybrid_shape=hybrid_shape)

# Should throw if there is a zero in the batch size
dense = make_tensor((0,) + shape, dtype=torch.float, device=device)
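
An illustrative stand-alone version (hypothetical, not part of the test suite) of what the updated branch expects for the non-blocked layouts: conversion succeeds with differing per-batch NSE and round-trips through to_dense, mirroring check_content; blocked layouts (BSR/BSC) are still expected to reject such input with the error asserted above.

import torch

# Three batches with different numbers of nonzeros (1, 20, and 0).
dense = torch.zeros(3, 4, 5)
dense[0, 0, 0] = 1.0
dense[1] = 1.0

csr = dense.to_sparse_csr()
csc = dense.to_sparse_csc()
assert torch.equal(csr.to_dense(), dense)  # round-trip preserves the dense values
assert torch.equal(csc.to_dense(), dense)
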