tsai models should be TorchScriptable #561

Open
ivanzvonkov opened this issue Aug 16, 2022 · 4 comments
Labels: enhancement (New feature or request), ideas (New ideas to enhance tsai)

Comments

@ivanzvonkov

Context: A common way to deploy models is using TorchServe (https://pytorch.org/serve/). The simplest way to do this requires a TorchScripted model (https://pytorch.org/docs/stable/jit.html).

Issue: While many tsai model architectures are already TorchScriptable, some are not, for various reasons. This makes it difficult to deploy those models for production use. For example, the code below shows that TransformerModel is TorchScriptable while LSTM is not.

In [1]: import torch

In [2]: from tsai.models.RNN import LSTM

In [3]: from tsai.models.TransformerModel import TransformerModel

In [4]: optimized_transformer = torch.jit.script(TransformerModel(c_in=2, c_out=1))
<RETRACTED>/venv/lib/python3.9/site-packages/torch/jit/_recursive.py:240: UserWarning: 'batch_first' was found in ScriptModule constants, but was not actually set in __init__. Consider removing it.
  warnings.warn("'{}' was found in ScriptModule constants, "
<RETRACTED>/venv/lib/python3.9/site-packages/torch/jit/_recursive.py:234: UserWarning: 'norm' was found in ScriptModule constants,  but it is a non-constant submodule. Consider removing it.
  warnings.warn("'{}' was found in ScriptModule constants, "

In [5]: optimized_lstm = torch.jit.script(LSTM(c_in=2, c_out=1))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [5], in <cell line: 1>()
----> 1 optimized_lstm = torch.jit.script(LSTM(c_in=2, c_out=1))

File <RETRACTED>/venv/lib/python3.9/site-packages/torch/jit/_script.py:1265, in script(obj, optimize, _frames_up, _rcb, example_inputs)
   1263 if isinstance(obj, torch.nn.Module):
   1264     obj = call_prepare_scriptable_func(obj)
-> 1265     return torch.jit._recursive.create_script_module(
   1266         obj, torch.jit._recursive.infer_methods_to_compile
   1267     )
   1269 if isinstance(obj, dict):
   1270     return create_script_dict(obj)

File <RETRACTED>/venv/lib/python3.9/site-packages/torch/jit/_recursive.py:454, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
    452 if not is_tracing:
    453     AttributeTypeIsSupportedChecker().check(nn_module)
--> 454 return create_script_module_impl(nn_module, concrete_type, stubs_fn)

File <RETRACTED>/venv/lib/python3.9/site-packages/torch/jit/_recursive.py:520, in create_script_module_impl(nn_module, concrete_type, stubs_fn)
    518 # Compile methods if necessary
    519 if concrete_type not in concrete_type_store.methods_compiled:
--> 520     create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    521     # Create hooks after methods to ensure no name collisions between hooks and methods.
    522     # If done before, hooks can overshadow methods that aren't exported.
    523     create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)

File <RETRACTED>/venv/lib/python3.9/site-packages/torch/jit/_recursive.py:371, in create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
    368 property_defs = [p.def_ for p in property_stubs]
    369 property_rcbs = [p.resolution_callback for p in property_stubs]
--> 371 concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)

RuntimeError:
Module 'LSTM' has no attribute 'dropout' (This function exists as an attribute on the Python module, but we failed to compile it to a TorchScript function.
The error stack is reproduced here:
Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "<RETRACTED>/venv/lib/python3.9/site-packages/fastai/imports.py", line 66
def noop (x=None, *args, **kwargs):
                          ~~~~~~~ <--- HERE
    "Do nothing"
    return x
:
  File "<RETRACTED>/venv/lib/python3.9/site-packages/tsai/models/RNN.py", line 21
        output, _ = self.rnn(x) # output from all sequence steps: [batch_size x seq_len x hidden_size * (1 + bidirectional)]
        output = output[:, -1]  # output from last sequence step : [batch_size x hidden_size * (1 + bidirectional)]
        output = self.fc(self.dropout(output))
                         ~~~~~~~~~~~~ <--- HERE
        return output

Potential solution:
For the LSTM (and other RNN models), the issue lies specifically in this line:

self.dropout = nn.Dropout(fc_dropout) if fc_dropout else noop

I believe substituting noop with nn.Identity() will keep existing behavior and make the model TorchScriptable.
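
For illustration, a minimal sketch of the proposed change in tsai/models/RNN.py (a hypothetical before/after; the surrounding class code is assumed):

# Before: not scriptable, because fastai's noop takes *args/**kwargs,
# which TorchScript rejects (see the error above).
self.dropout = nn.Dropout(fc_dropout) if fc_dropout else noop

# After: nn.Identity() is scriptable and likewise returns its input unchanged.
self.dropout = nn.Dropout(fc_dropout) if fc_dropout else nn.Identity()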

More generally, adding an integration test that checks all models for TorchScriptability would be a sensible first step; individual model architectures could then be addressed on a case-by-case basis. A sketch of such a test follows.
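
A minimal, hypothetical pytest-style version of that check (the model list and constructor arguments are assumptions; the input shape follows tsai's [batch_size x c_in x seq_len] convention):

import pytest
import torch
from tsai.models.RNN import LSTM
from tsai.models.TransformerModel import TransformerModel

# Hypothetical list; in practice this would cover every architecture in the README.
MODELS = [LSTM, TransformerModel]

@pytest.mark.parametrize("Model", MODELS)
def test_model_is_torchscriptable(Model):
    model = Model(c_in=2, c_out=1).eval()
    scripted = torch.jit.script(model)            # raises if the model is not scriptable
    x = torch.rand(8, 2, 12)                      # [batch_size x c_in x seq_len]
    assert torch.allclose(model(x), scripted(x))  # scripting should preserve outputs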

@oguiza
Contributor

oguiza commented Oct 21, 2022

Hi @ivanzvonkov ,
I fully agree with you. It's something I've always had in mind, and I'll start doing it from now on.
If you look at the RNN documentation, you'll see I've added some tests and shown how you can convert models to TorchScript and/or ONNX; a minimal sketch of both conversions follows below.
Are you aware of any other models that cannot be converted?
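
For reference, a minimal sketch of both conversions (file names and shapes are illustrative, and it assumes the noop-to-nn.Identity() fix above so that LSTM scripts cleanly):

import torch
from tsai.models.RNN import LSTM

model = LSTM(c_in=2, c_out=1).eval()

# TorchScript: script and serialize; the saved file can be served with TorchServe.
scripted = torch.jit.script(model)
scripted.save("lstm_scripted.pt")

# ONNX: export needs a sample input of the expected [batch_size x c_in x seq_len] shape.
torch.onnx.export(model, torch.rand(1, 2, 12), "lstm.onnx",
                  input_names=["input"], output_names=["output"])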

@oguiza added the enhancement and ideas labels on Oct 21, 2022
@ivanzvonkov
Author

Thank you for making the updates to the RNN!
I recall quite a few models from the README were not TorchScriptable, including FCN, TCN, InceptionTime, Rocket, TST, and TabTransformer. Is there a generalized way of testing all models within this framework?

@oguiza
Contributor

oguiza commented Dec 7, 2022

There isn't a way to test all models yet, but I'll add that functionality soon, as it makes sense.

@ivanzvonkov
Author

Great. Having some guarantees about which models are TorchScriptable will make it much easier to test different models in projects where deployment is required for full evaluation.
