Skip to content

Commit

Permalink
Add reference_inputs_sparse_coo/csr/csc_func attributes to OpInfo
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
pearu committed Sep 19, 2022
1 parent fc3bff7 commit 54241ca
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions torch/testing/_internal/opinfo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,15 @@ class OpInfo(object):
# function to generate sample inputs with sparse bsc layouts
sample_inputs_sparse_bsc_func: Callable = None

# function to generate a more thorough set of sample inputs with sparse coo layouts;
# when None, reference_inputs_sparse_coo() falls back to sample_inputs_sparse_coo_func.
# When set, __post_init__ wraps it with torch.no_grad().
reference_inputs_sparse_coo_func: Callable = None

# function to generate a more thorough set of sample inputs with sparse csr layouts;
# when None, reference_inputs_sparse_csr() falls back to sample_inputs_sparse_csr_func.
# When set, __post_init__ wraps it with torch.no_grad().
reference_inputs_sparse_csr_func: Callable = None

# function to generate a more thorough set of sample inputs with sparse csc layouts;
# when None, reference_inputs_sparse_csc() falls back to sample_inputs_sparse_csc_func.
# When set, __post_init__ wraps it with torch.no_grad().
reference_inputs_sparse_csc_func: Callable = None

# the following metadata relates to dtype support and is tested for correctness in test_ops.py

# dtypes this function works with on the CPU,
Expand Down Expand Up @@ -934,6 +943,12 @@ def __post_init__(self):
)
# Wrap each user-supplied reference-input generator in torch.no_grad() so that
# constructing the inputs does not record autograd history. The sparse-layout
# generators receive the same treatment as the generic reference_inputs_func;
# unset (None) generators are left alone so the reference_inputs_sparse_*()
# methods can detect them and fall back to the sample-input functions.
if self.reference_inputs_func is not None:
self.reference_inputs_func = torch.no_grad()(self.reference_inputs_func)
if self.reference_inputs_sparse_coo_func is not None:
self.reference_inputs_sparse_coo_func = torch.no_grad()(self.reference_inputs_sparse_coo_func)
if self.reference_inputs_sparse_csr_func is not None:
self.reference_inputs_sparse_csr_func = torch.no_grad()(self.reference_inputs_sparse_csr_func)
if self.reference_inputs_sparse_csc_func is not None:
self.reference_inputs_sparse_csc_func = torch.no_grad()(self.reference_inputs_sparse_csc_func)

if not self.autodiff_fusible_nodes:
self.autodiff_fusible_nodes = []
Expand Down Expand Up @@ -1159,6 +1174,39 @@ def sample_inputs_sparse_bsc(self, device, dtype, requires_grad=False, **kwargs)
self, device, dtype, requires_grad, **kwargs
)

def reference_inputs_sparse_coo(self, device, dtype, requires_grad=False, **kwargs):
    """Returns the reference SampleInputs for the sparse coo layout.

    Uses ``reference_inputs_sparse_coo_func`` when it is set and otherwise
    falls back to ``sample_inputs_sparse_coo_func``. The chosen function is
    called as ``func(self, device, dtype, requires_grad, **kwargs)`` and its
    result is returned unchanged.
    """
    inputs_func = self.reference_inputs_sparse_coo_func
    if inputs_func is None:
        # No dedicated reference generator: reuse the regular coo samples.
        inputs_func = self.sample_inputs_sparse_coo_func
    return inputs_func(self, device, dtype, requires_grad, **kwargs)

def reference_inputs_sparse_csr(self, device, dtype, requires_grad=False, **kwargs):
    """Returns the reference SampleInputs for the sparse csr layout.

    Uses ``reference_inputs_sparse_csr_func`` when it is set and otherwise
    falls back to ``sample_inputs_sparse_csr_func``. The chosen function is
    called as ``func(self, device, dtype, requires_grad, **kwargs)`` and its
    result is returned unchanged.
    """
    inputs_func = self.reference_inputs_sparse_csr_func
    if inputs_func is None:
        # No dedicated reference generator: reuse the regular csr samples.
        inputs_func = self.sample_inputs_sparse_csr_func
    return inputs_func(self, device, dtype, requires_grad, **kwargs)

def reference_inputs_sparse_csc(self, device, dtype, requires_grad=False, **kwargs):
    """Returns the reference SampleInputs for the sparse csc layout.

    Uses ``reference_inputs_sparse_csc_func`` when it is set and otherwise
    falls back to ``sample_inputs_sparse_csc_func``. The chosen function is
    called as ``func(self, device, dtype, requires_grad, **kwargs)`` and its
    result is returned unchanged.
    """
    inputs_func = self.reference_inputs_sparse_csc_func
    if inputs_func is None:
        # No dedicated reference generator: reuse the regular csc samples.
        inputs_func = self.sample_inputs_sparse_csc_func
    return inputs_func(self, device, dtype, requires_grad, **kwargs)

def get_decorators(self, test_class, test_name, device, dtype):
"""Returns the decorators targeting the given test."""
result = []
Expand Down

0 comments on commit 54241ca

Please sign in to comment.