Update on "Support different NSE in batches of CSR and CSC tensors"
This PR enables batched CSR/CSC tensors whose batches may have different NSE (number of specified elements) counts.

For instance, with the current master we have
```python
>>> a = torch.tensor([[[1, 2], [3, 4]], [[0, 12], [21, 0]]])
>>> a.to_sparse_csr()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expect the same number of specified elements per batch.
```
because the NSE of the first and second batches are different, 4 and 2, respectively.
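
For reference, the differing per-batch NSE counts can be computed with plain dense ops (an illustrative sketch, not part of this PR):
```python
>>> a = torch.tensor([[[1, 2], [3, 4]], [[0, 12], [21, 0]]])
>>> (a != 0).sum(dim=(-2, -1))  # number of specified elements per batch
tensor([4, 2])
```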

This PR implements a strided-to-sparse-CSR/CSC conversion algorithm that supports CSR/CSC batches with different NSE counts. For instance:
```python
>>> a = torch.tensor([[[1, 2], [3, 4]], [[0, 12], [21, 0]]])
>>> b = a.to_sparse_csr()
>>> b
tensor(crow_indices=tensor([[0, 2, 4],
                            [0, 1, 2]]),
       col_indices=tensor([[0, 1, 0, 1],
                           [1, 0, 0, 0]]),
       values=tensor([[ 1,  2,  3,  4],
                      [12, 21,  0,  0]]), size=(2, 2, 2), nnz=4,
       layout=torch.sparse_csr)
>>> b[0]
tensor(crow_indices=tensor([0, 2, 4]),
       col_indices=tensor([0, 1, 0, 1]),
       values=tensor([1, 2, 3, 4]), size=(2, 2), nnz=4,
       layout=torch.sparse_csr)
>>> b[1]
tensor(crow_indices=tensor([0, 1, 2]),
       col_indices=tensor([1, 0]),
       values=tensor([12, 21]), size=(2, 2), nnz=2, layout=torch.sparse_csr)
```
That is, if the NSE of a batch is smaller than the maximum NSE over all batches, the corresponding rows in `col_indices`/`values` are padded with zeros as placeholders. Algorithms operating on batched CSR/CSC tensors must not access these padded parts: they should read the NSE of each batch from the last element of the corresponding `crow_indices` row rather than from `.values().shape[0]`, which holds the maximum NSE over all batches.
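
A minimal sketch of this convention, using the tensor `b` from the example above:
```python
>>> b.crow_indices()[:, -1]  # per-batch NSE: last element of each crow_indices row
tensor([4, 2])
>>> b.values()[1]  # the second batch is zero-padded beyond its NSE of 2
tensor([12, 21,  0,  0])
```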

Performance-wise, the strided-to-sparse-CSR/CSC conversion algorithms in master and in this PR are roughly equivalent:
```python
# master branch:
In [2]: a = torch.rand(10, 10, 1000, 1000)

In [3]: a = torch.where(a==0, 0.1, a)  # required for master, optional for the PR

In [4]: %timeit a.to_sparse_csr()
2.25 s ± 9.84 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: a_cuda = a.cuda()

In [6]: %timeit a_cuda.to_sparse_csr()
55.2 ms ± 6.95 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
```python
# this PR
In [2]: a = torch.rand(10, 10, 1000, 1000)

In [3]: a = torch.where(a==0, 0.1, a)  # required for master, optional for the PR

In [4]: %timeit a.to_sparse_csr()
2.12 s ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: a_cuda = a.cuda()

In [6]: %timeit a_cuda.to_sparse_csr(); torch.cuda.synchronize()
47.2 ms ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
The performance of `to_sparse_csr()` on CUDA tensors improves by roughly 15% with this PR.
 
A strided-to-sparse-BSR/BSC conversion with variable NSE support will be implemented as a follow-up.




[ghstack-poisoned]
pearu committed Sep 12, 2022
2 parents 16bad3d + e87e40b commit 90fe9e9
Showing 8 changed files with 94 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
@@ -1 +1 @@
-f00dd2f35ecf6455d97237d63c70c9c8ec190940
+e0dcc3171c8024ab288551d105fba24fbfae7332
19 changes: 10 additions & 9 deletions aten/src/ATen/native/Normalization.cpp
@@ -14,6 +14,7 @@
#include <c10/util/irange.h>

#include <vector>
+#include <c10/core/SymIntArrayRef.h>

static const int MIOPEN_DIM_MAX = 5;

@@ -41,7 +42,7 @@ DEFINE_DISPATCH(batch_norm_cpu_backward_stub);
DEFINE_DISPATCH(renorm_scale_factor_stub);

namespace {
-void check_dims_match_num_input_features(const char* arg_name, int64_t expected, int64_t actual){
+void check_dims_match_num_input_features(const char* arg_name, SymInt expected, SymInt actual){
TORCH_CHECK(actual == expected,
arg_name, " should contain ", expected, " elements not ", actual);
}
@@ -443,14 +444,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

-auto num_features = input.sizes()[1];
+auto num_features = input.sym_sizes()[1];

-if (input.numel() == 0) {
+if (input.sym_numel() == 0) {
Tensor reserve = at::empty({0}, input.options().dtype(kByte));
auto options = input.options().dtype(
at::toAccumulateType(input.scalar_type(), /*is_cuda=*/input.is_cuda()));
-auto save_mean = at::empty({num_features}, options);
-auto save_invstd = at::empty({num_features}, options);
+auto save_mean = at::empty_symint(c10::SymIntArrayRef({num_features}), options);
+auto save_invstd = at::empty_symint(c10::SymIntArrayRef({num_features}), options);

// don't return view of input, don't return empty tensor because it will break gradient chain
auto out = input.clone();
@@ -461,20 +462,20 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
}

if (running_mean.defined()) {
check_dims_match_num_input_features("running_mean", num_features, running_mean.numel());
check_dims_match_num_input_features("running_mean", num_features, running_mean.sym_numel());
} else if (!training) {
AT_ERROR("running_mean must be defined in evaluation mode");
}
if (running_var.defined()) {
check_dims_match_num_input_features("running_var", num_features, running_var.numel());
check_dims_match_num_input_features("running_var", num_features, running_var.sym_numel());
} else if (!training) {
AT_ERROR("running_var must be defined in evaluation mode");
}
if (weight.defined()) {
check_dims_match_num_input_features("weight", num_features, weight.numel());
check_dims_match_num_input_features("weight", num_features, weight.sym_numel());
}
if (bias.defined()) {
check_dims_match_num_input_features("bias", num_features, bias.numel());
check_dims_match_num_input_features("bias", num_features, bias.sym_numel());
}

const bool use_cudnn = (
2 changes: 1 addition & 1 deletion test/test_dynamic_shapes.py
@@ -115,7 +115,7 @@ def create_symbolic_tensor(name, arg, shape_env):
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device)


-CPP_SYMINT_CLASS = type(torch._C.SymIntNode.new_symint(1))
+CPP_SYMINT_CLASS = type(torch.SymIntNode.new_symint(1))


@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
7 changes: 6 additions & 1 deletion torch/_meta_registrations.py
@@ -85,6 +85,11 @@ def meta_fft_c2r(self, dim, normalization, lastdim):
return self.new_empty(output_sizes, dtype=toRealValueType(self.dtype))


+@register_meta(aten.copy_.default, register_dispatcher=False)
+def meta_copy_(self, src, non_blocking=False):
+return self


# Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
@register_meta(aten.index_select.default)
def meta_index_select(self, dim, index):
@@ -327,7 +332,7 @@ def pick_memory_format():

else:
out_channels = weight.shape[0]
-if weight.shape[1] != input_tensor.shape[1] / groups:
+if weight.shape[1] * groups != input_tensor.shape[1]:
raise RuntimeError("Invalid channel dimensions")
shape_out = calc_conv_nd_return_shape(
dims, kernel_size, stride, padding, dilation
7 changes: 5 additions & 2 deletions torch/_subclasses/fake_tensor.py
@@ -1,6 +1,7 @@
import contextlib
import functools
import itertools
+import sys
import warnings
import weakref
from dataclasses import dataclass
@@ -660,7 +661,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
return args[0].fake_device

flat_arg_tensors = tree_flatten_only(FakeTensor, (args, kwargs))
-flat_symints = tree_flatten_only(torch._C.SymIntNode, (args, kwargs))
+flat_symints = tree_flatten_only(torch.SymIntNode, (args, kwargs))
has_symbolic_sizes = (
any([i.has_sym_ints for i in flat_arg_tensors]) or len(flat_symints) > 0
)
@@ -730,10 +731,12 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if symbolic_shapes.is_symbolic_op(func):
return symbolic_shapes.handle_symbolic_op(func, args, kwargs)
if func == aten.size.default:
-raise RuntimeError(
+sys.stderr.write(
"Trying to call aten.size on a tensor with symbolic shapes. "
"It's likely that this is from calling tensor.shape in C++"
)
+# We do this to allow for better error localization with `TORCH_SHOW_CPP_STACKTRACES=1`
+return None

with self.restore():
if func in meta_table:
6 changes: 3 additions & 3 deletions torch/fx/experimental/symbolic_shapes.py
@@ -59,7 +59,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
SYM_FUNCTION_MODE = self.inner

def has_symbolic_sizes_strides(elem):
-return any([isinstance(i, torch._C.SymIntNode) for i in elem.shape])
+return any([isinstance(i, torch.SymIntNode) for i in elem.shape])

def create_contiguous(shape):
strides = [1]
@@ -189,7 +189,7 @@ def magic_impl(self, other):
return PySymInt(func(self.expr, other), self.shape_env)
return magic_impl

-# this should be wrapped transparently into torch._C.SymIntNode
+# this should be wrapped transparently into torch.SymIntNode
setattr(PySymInt, method, _create_magic_impl(_func))
setattr(PySymInt, f"__{method}__", _create_magic_impl(_func))
if method in reflectable_magic_methods:
@@ -210,7 +210,7 @@ def create_symint(self, name, val, shape_env=None):
return val
sympy_expr = sympy.Symbol(name, positive=True, integer=True)
py_sym_int = PySymInt(sympy_expr, self)
-cpp_sym_int = torch._C.SymIntNode.new_symint(py_sym_int) # type: ignore[attr-defined]
+cpp_sym_int = torch.SymIntNode.new_symint(py_sym_int) # type: ignore[attr-defined]
shape_env[sympy_expr] = val
return cpp_sym_int

31 changes: 25 additions & 6 deletions torchgen/api/lazy.py
@@ -203,12 +203,16 @@ class LazyArgument:
# TODO: this is lies, it is false for symint list
is_symint_or_list: bool

+# Whether or not we are treating this as symint or not
+symint: bool

# true if this argument is or contains a lazy IR value
is_lazy_value: bool

-def __init__(self, arg: Argument, properties: "LazyIrProperties"):
+def __init__(self, arg: Argument, properties: "LazyIrProperties", *, symint: bool):
self.name = arg.name
self.orig_type = arg.type
+self.symint = symint
self.is_optional = isinstance(arg.type, OptionalType)
self.is_generator = isGeneratorType(arg.type)
if self.is_generator:
@@ -222,7 +226,7 @@ def __init__(self, arg: Argument, properties: "LazyIrProperties"):
else:
self.lazy_type_ = process_ir_type(arg.type, properties)
self.is_wrapped_scalar = isWrappedScalarType(arg.type)
-self.is_symint_or_list = (
+self.is_symint_or_list = symint and (
isSymIntType(arg.type)
or (isinstance(arg.type, OptionalType) and isSymIntType(arg.type.elem))
# TODO: lists of symints are not currently treated as value types
@@ -319,6 +323,12 @@ class LazyIrSchema:
# build a LazyArgument since lazy IR doesn't support it
generator_arg: Optional[NamedCType] = None

+# original function schema
+func: FunctionSchema

+# Whether or not we are code-genning for SymInt or not
+symint: bool

properties: LazyIrProperties = LazyIrProperties(
# default properties
"ShapePrecompute",
@@ -328,19 +338,27 @@
opkind: Optional[str] = None

def __init__(
-self, func: FunctionSchema, properties: Optional[LazyIrProperties] = None
+self,
+func: FunctionSchema,
+properties: Optional[LazyIrProperties] = None,
+*,
+symint: bool,
):
if properties:
self.properties = properties

+self.func = func
+self.symint = symint
positional_args: List[LazyArgument] = []
for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
if arg_field == "self_arg" and func.arguments.self_arg is not None:
arg = getattr(func.arguments, "self_arg").argument
-positional_args.append(LazyArgument(arg, self.properties))
+positional_args.append(
+LazyArgument(arg, self.properties, symint=symint)
+)
elif getattr(func.arguments, arg_field) is not None:
positional_args.extend(
-LazyArgument(arg, self.properties)
+LazyArgument(arg, self.properties, symint=symint)
for arg in getattr(func.arguments, arg_field)
)
self.positional_args = tuple(positional_args)
@@ -363,7 +381,8 @@ def __init__(
), "We expect there is only one generator arg"
self.generator_arg = NamedCType(arg.name, arg.type)
keyword_args.extend(
-LazyArgument(arg, self.properties) for arg in curr_args
+LazyArgument(arg, self.properties, symint=symint)
+for arg in curr_args
)
self.keyword_args = tuple(keyword_args)
self.name = func.name
59 changes: 43 additions & 16 deletions torchgen/dest/lazy_ir.py
@@ -19,6 +19,7 @@
deviceT,
DispatcherSignature,
kernel_signature,
+NativeSignature,
OptionalCType,
VectorCType,
)
@@ -27,6 +28,7 @@
from torchgen.model import (
Argument,
BackendIndex,
+BackendMetadata,
BaseTy,
BaseType,
FunctionSchema,
@@ -77,7 +79,10 @@ def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType(
BaseTy.SymInt
):
return f"GetSymIntArrayRefValue({arg.name})"
if arg.symint:
return f"GetSymIntArrayRefValue({arg.name})"
else:
return f"std::vector<int64_t>({arg.name}.begin(), {arg.name}.end())"
elif isinstance(arg.lazy_type, VectorCType) and isinstance(
arg.lazy_type.elem, BaseCType
):
@@ -102,13 +107,17 @@ def node_ctor_inputs(schema: LazyIrSchema) -> str:
return ", ".join(node_ctor_values)


-def gen_fallback_code(schema: LazyIrSchema, overload_name: str) -> str:
+def gen_fallback_code(
+schema: LazyIrSchema,
+sig: Union[DispatcherSignature, NativeSignature],
+overload_name: str,
+) -> str:
"""
Generate code that falls back to eager conditioned on a predicate
"""
-fallback_args = ",\n ".join(
-[str(arg.name) for arg in schema.filtered_args(generator=True)]
-)
+dispatcher_sig = DispatcherSignature.from_schema(schema.func)
+exprs = translate(sig.arguments(), dispatcher_sig.arguments())
+fallback_args = ",\n ".join([a.expr for a in exprs])
if len(overload_name):
aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
else:
@@ -167,7 +176,12 @@ class GenLazyIR(ABC):
@method_with_native_function
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
-schema = LazyIrSchema(func)
+metadata = self.backend_index.get_kernel(
+f.functional if isinstance(f, NativeFunctionsGroup) else f
+)
+schema = LazyIrSchema(
+func, symint=metadata is not None and metadata.supports_symint()
+)
return self.gen(schema)

# there is no lowering functionality generated unless this IR base class is subclassed and
Expand Down Expand Up @@ -444,9 +458,17 @@ def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
)
return ("\n ").join(lazy_tensor_decls)

-def force_eager_fallback(self, func: NativeFunction, schema: LazyIrSchema) -> str:
+def force_eager_fallback(
+self,
+func: NativeFunction,
+schema: LazyIrSchema,
+metadata: BackendMetadata,
+sig: Union[DispatcherSignature, NativeSignature],
+) -> str:
if self.gen_forced_fallback_code:
-return gen_fallback_code(schema, overload_name=func.func.name.overload_name)
+return gen_fallback_code(
+schema, sig, overload_name=func.func.name.overload_name
+)
return ""

def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str:
@@ -525,7 +547,9 @@ def this_shape(i: int) -> str:
auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
{meta_out}"""
else:
-shape_sig = ComputeShapeSignature(metadata.kernel, func)
+shape_sig = ComputeShapeSignature(
+metadata.kernel, func, symint=metadata.supports_symint()
+)
shape_str = f"""
auto shapes = {shape_sig.shape_call};"""

@@ -598,11 +622,11 @@ def __call__(self, func: NativeFunction) -> List[str]:
sig = kernel_signature(func, self.backend_index)
metadata = self.backend_index.get_kernel(func)
assert metadata is not None
-schema = LazyIrSchema(func.func)
+schema = LazyIrSchema(func.func, symint=metadata.supports_symint())
return [
f"""\
{sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
-{self.force_eager_fallback(func, schema)}
+{self.force_eager_fallback(func, schema, metadata, sig)}
{self.metrics(func, schema)}
{self.get_device(func, schema)}
{self.lazy_tensor_decls(func, schema)}
@@ -618,10 +642,10 @@
Here we use the base name as the suffix of the signature to avoid generating for in-place variants.
"""

-def __init__(self, kernel_name: str, f: NativeFunction):
-self.__schema = LazyIrSchema(f.func)
+def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool):
+self.__schema = LazyIrSchema(f.func, symint=symint)
self.__dispatch_args = ", ".join(
-[a.decl() for a in dispatcher.arguments(f.func)]
+[a.decl() for a in dispatcher.arguments(f.func, symint=symint)]
)
self.__call_args = ", ".join(
[f"{arg.name}" for arg in self.__schema.filtered_args(generator=True)]
@@ -660,7 +684,9 @@ def __call__(self, f: NativeFunction) -> List[str]:
if is_structured or is_view_copy_op:
return []
else:
-shape_sig = ComputeShapeSignature(metadata.kernel, f)
+shape_sig = ComputeShapeSignature(
+metadata.kernel, f, symint=metadata.supports_symint()
+)
return ["\n".join([f"{shape_sig.shape_decl};"])]


@@ -675,7 +701,8 @@ def generate_non_native_lazy_ir_nodes(
for p in op.get("properties", []):
setattr(properties, p, True)

-schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties)
+# non-native is assumed to want symint bindings if you wrote symint
+schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties, symint=True)
schema.opkind = op.get("opkind")
nodes.append(gen_lazy_ir.gen(schema)[0])

