Does build_tabular_model accept TabTransformer? #358

Open
strakehyr opened this issue Jan 11, 2022 · 3 comments
Labels
enhancement New feature or request

Comments

@strakehyr

No description provided.

@strakehyr
Author

While testing build_tabular_model, I tried passing it TabTransformer, but it failed to build a model:

```python
tab = build_tabular_model(TabTransformer, dls=dls_cat)
```

The error I got was:

```
TypeError: __init__() got an unexpected keyword argument 'y_range'
```
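The error suggests that build_tabular_model forwards TabularModel-specific keyword arguments such as `y_range` to whatever architecture it receives, and TabTransformer's `__init__` does not declare that parameter. Below is a minimal sketch of one way a helper could avoid this, assuming it is acceptable to silently drop unsupported arguments: inspect the constructor's signature and pass on only the keywords it accepts. `filter_kwargs`, `FakeTabularModel` and `FakeTabTransformer` are hypothetical stand-ins, not tsai's code.

```python
import inspect


def filter_kwargs(cls, kwargs):
    """Keep only the keyword arguments that cls.__init__ accepts."""
    params = inspect.signature(cls.__init__).parameters
    # If the constructor takes **kwargs, every keyword is accepted.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {k: v for k, v in kwargs.items() if k in params}


# Hypothetical stand-ins for the two architectures (NOT tsai's classes).
class FakeTabularModel:
    def __init__(self, emb_szs, n_cont, out_sz, layers, y_range=None):
        self.y_range = y_range


class FakeTabTransformer:
    def __init__(self, classes, cont_names, c_out, n_layers=6):
        self.n_layers = n_layers


extra = {"y_range": (0, 1), "n_layers": 2}
print(filter_kwargs(FakeTabularModel, extra))    # {'y_range': (0, 1)}
print(filter_kwargs(FakeTabTransformer, extra))  # {'n_layers': 2}

# The filtered dict can be splatted into the constructor without a TypeError.
model = FakeTabTransformer({}, [], 2, **filter_kwargs(FakeTabTransformer, extra))
```

Dropping arguments silently can hide typos in keyword names; a stricter variant could raise a clearer error listing the rejected keys instead.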

@radi-cho

Hello. I am not a maintainer of the repository, but my PRs #365 and #362 are currently awaiting approval, so I am getting familiar with the codebase. From https://github.com/timeseriesAI/tsai/blob/main/tsai/models/utils.py#L177 and https://github.com/timeseriesAI/tsai/blob/main/tsai/tslearner.py#L50 we can see that build_tabular_model and the TSLearners are customized only for TabularModel, not for TabTransformer. I plan to open a new pull request to improve support for TabTransformer (and GatedTabTransformer) once my current PRs are finished. In the meantime, you can support them with a thumbs up :)
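To illustrate why a builder tailored to TabularModel does not fit TabTransformer: the two take different constructor inputs. A TabularModel-style model needs embedding sizes, a count of continuous features, and a layer list, while TabTransformer (as used in the workaround in this thread) takes `dls.classes`, `dls.cont_names` and `dls.c`. The sketch below is a hypothetical builder that branches on the architecture; the stub classes, `simple_emb_sz` rule and `build_model` function are assumptions for illustration, not tsai's build_tabular_model.

```python
from types import SimpleNamespace


class StubTabularModel:
    """Stand-in for a TabularModel-style arch: embedding sizes + #continuous."""
    def __init__(self, emb_szs, n_cont, out_sz, layers, y_range=None):
        self.emb_szs, self.n_cont, self.out_sz = emb_szs, n_cont, out_sz
        self.layers, self.y_range = layers, y_range


class StubTabTransformer:
    """Stand-in for TabTransformer: category levels and column names."""
    def __init__(self, classes, cont_names, c_out):
        self.classes, self.cont_names, self.c_out = classes, cont_names, c_out


def simple_emb_sz(n_levels):
    # Hypothetical rule of thumb; fastai uses its own embedding-size rule.
    return (n_levels, min(50, (n_levels + 1) // 2))


def build_model(arch, dls, layers=(200, 100), y_range=None):
    """Hypothetical builder that picks constructor arguments per architecture."""
    if arch is StubTabTransformer:
        # TabTransformer-style signature: no layers, no y_range.
        return arch(dls.classes, dls.cont_names, dls.c)
    emb_szs = [simple_emb_sz(len(levels)) for levels in dls.classes.values()]
    return arch(emb_szs, len(dls.cont_names), dls.c, list(layers), y_range=y_range)


# A minimal dls-like object exposing only the attributes the builder reads.
dls = SimpleNamespace(
    classes={"workclass": ["#na#", "Private", "State-gov"],
             "race": ["#na#", "White", "Black", "Other"]},
    cont_names=["age", "fnlwgt"],
    c=2,
)
tt = build_model(StubTabTransformer, dls)
tm = build_model(StubTabularModel, dls, y_range=(0, 1))
print(tm.emb_szs)  # [(3, 2), (4, 2)]
```

Dispatching on the class is the simplest option; a more general builder could instead inspect each constructor's signature.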

@oguiza
Contributor

oguiza commented Jan 18, 2022

Hi @strakehyr and @radi-cho,
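That's a good point. It'd be good to have TabTransformer supported by build_tabular_model. Since build_tabular_model is just a convenience function, the workaround is to create the model and the Learner directly:

```python
import pandas as pd
from fastai.tabular.all import *  # untar_data, URLs, TabularDataLoaders, Categorify, FillMissing, Normalize, Learner
from tsai.models.TabTransformer import TabTransformer

path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')  # optional: inspect the raw data
dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, y_names="salary",
    cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race'],
    cont_names = ['age', 'fnlwgt', 'education-num'],
    procs = [Categorify, FillMissing, Normalize])
# TabTransformer takes the categorical levels, the continuous column names and the number of outputs
model = TabTransformer(dls.classes, dls.cont_names, dls.c)
learn = Learner(dls, model)
learn.fit_one_cycle(1)
```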

@oguiza oguiza added the enhancement New feature or request label Jan 18, 2022

3 participants