You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
However, after adding the lines above and running `get_X_preds()`, I now get this shape error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[5], line 2
1 learn = load_learner('InceptionTime_default_40epochs')
----> 2 learn.get_X_preds(X)
File ~\miniconda3\lib\site-packages\tsai\inference.py:18, in get_X_preds(self, X, y, bs, with_input, with_decoded, with_loss)
16 with_loss = False
17 dl = self.dls.valid.new_dl(X, y=y, bs=bs)
---> 18 output = list(self.get_preds(dl=dl, with_input=with_input, with_decoded=with_decoded, with_loss=with_loss, reorder=False))
19 if with_decoded and len(self.dls.tls) >= 2 and hasattr(self.dls.tls[-1], "tfms") and hasattr(self.dls.tls[-1].tfms, "decodes"):
20 output[2 + with_input] = self.dls.tls[-1].tfms.decode(output[2 + with_input])
File ~\miniconda3\lib\site-packages\fastai\learner.py:308, in Learner.get_preds(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)
306 if with_loss: ctx_mgrs.append(self.loss_not_reduced())
307 with ContextManagers(ctx_mgrs):
--> 308 self._do_epoch_validate(dl=dl)
309 if act is None: act = getcallable(self.loss_func, 'activation')
310 res = cb.all_tensors()
File ~\miniconda3\lib\site-packages\fastai\learner.py:244, in Learner._do_epoch_validate(self, ds_idx, dl)
242 if dl is None: dl = self.dls[ds_idx]
243 self.dl = dl
--> 244 with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)
File ~\miniconda3\lib\site-packages\fastai\learner.py:199, in Learner._with_events(self, f, event_type, ex, final)
198 def _with_events(self, f, event_type, ex, final=noop):
--> 199 try: self(f'before_{event_type}'); f()
200 except ex: self(f'after_cancel_{event_type}')
201 self(f'after_{event_type}'); final()
File ~\miniconda3\lib\site-packages\fastai\learner.py:205, in Learner.all_batches(self)
203 def all_batches(self):
204 self.n_iter = len(self.dl)
--> 205 for o in enumerate(self.dl): self.one_batch(*o)
File ~\miniconda3\lib\site-packages\tsai\learner.py:39, in one_batch(self, i, b)
37 b_on_device = to_device(b, device=self.dls.device) if self.dls.device is not None else b
38 self._split(b_on_device)
---> 39 self._with_events(self._do_one_batch, 'batch', CancelBatchException)
File ~\miniconda3\lib\site-packages\fastai\learner.py:199, in Learner._with_events(self, f, event_type, ex, final)
198 def _with_events(self, f, event_type, ex, final=noop):
--> 199 try: self(f'before_{event_type}'); f()
200 except ex: self(f'after_cancel_{event_type}')
201 self(f'after_{event_type}'); final()
File ~\miniconda3\lib\site-packages\fastai\learner.py:216, in Learner._do_one_batch(self)
215 def _do_one_batch(self):
--> 216 self.pred = self.model(*self.xb)
217 self('after_pred')
218 if len(self.yb):
File ~\miniconda3\lib\site-packages\torch\nn\modules\module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~\miniconda3\lib\site-packages\tsai\models\InceptionTime.py:67, in InceptionTime.forward(self, x)
65 def forward(self, x):
66 x = self.inceptionblock(x)
---> 67 x = self.gap(x)
68 x = self.fc(x)
69 return x
File ~\miniconda3\lib\site-packages\torch\nn\modules\module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~\miniconda3\lib\site-packages\tsai\models\layers.py:580, in GAP1d.forward(self, x)
579 def forward(self, x):
--> 580 return self.flatten(self.gap(x))
File ~\miniconda3\lib\site-packages\torch\nn\modules\module.py:1194, in Module._call_impl(self, *input, **kwargs)
1190 # If we don't have any hooks, we want to skip the rest of the logic in
1191 # this function, and just call forward.
1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
1196 full_backward_hooks, non_full_backward_hooks = [], []
File ~\miniconda3\lib\site-packages\tsai\models\layers.py:345, in Reshape.forward(self, x)
--> 345 def forward(self, x): return x.reshape(x.shape[0], *self.shape)
RuntimeError: shape '[1]' is invalid for input of size 128
The text was updated successfully, but these errors were encountered:
Versions in use:
IPython : 8.9.0
ipykernel : 6.19.2
ipywidgets : 7.6.5
jupyter_client : 7.4.9
jupyter_core : 5.2.0
jupyter_server : 1.23.4
jupyterlab : 3.5.3
nbclient : 0.5.13
nbconvert : 6.5.4
nbformat : 5.7.0
notebook : 6.5.2
qtconsole : 5.4.0
traitlets : 5.7.1
tsai : 0.3.4
python : 3.10.9
Files to reproduce the error:
test.csv https://github.com/epdavid1/xlpe/blob/main/test.csv
InceptionTime model https://github.com/epdavid1/xlpe/blob/main/InceptionTime_default_40epochs
Colab code (working):
Output:
Jupyter code (not working):
Output:
Now, to fix the `PosixPath` error, a quick Stack Overflow search showed that I need to add the following:
However, after adding the lines above and running `get_X_preds()`, I now get this shape error:
The text was updated successfully, but these errors were encountered: