Skip to content

Commit 82c5c43

Browse files
authored
Mark pytorch tests as p100 exempt (#1544)
Newer version of pytorch is needed for rtx pros / cuda sm_120, but now it no longer works on P100s / cuda sm_60 https://chat.kaggle.net/kaggle/pl/mxgoz3gscf853gez3yhmhpw6hr
1 parent ccaff01 commit 82c5c43

3 files changed

Lines changed: 10 additions & 1 deletion

File tree

tests/test_fastai.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import fastai
44
from fastai.tabular.all import *
55

6+
from common import p100_exempt
7+
68

79
class TestFastAI(unittest.TestCase):
810
# Basic import
@@ -22,6 +24,7 @@ def test_torch_tensor(self):
2224

2325
self.assertTrue(torch.all(a == b))
2426

27+
@p100_exempt
2528
def test_tabular(self):
2629
dls = TabularDataLoaders.from_csv(
2730
"/input/tests/data/train.csv",

tests/test_pytorch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch.nn as tnn
55
import torch.autograd as autograd
66

7-
from common import gpu_test
7+
from common import gpu_test, p100_exempt
88

99

1010
class TestPyTorch(unittest.TestCase):
@@ -16,6 +16,7 @@ def test_nn(self):
1616
linear_torch(data_torch)
1717

1818
@gpu_test
19+
@p100_exempt
1920
def test_linalg(self):
2021
A = torch.randn(3, 3).t().to('cuda')
2122
B = torch.randn(3).t().to('cuda')
@@ -24,6 +25,7 @@ def test_linalg(self):
2425
self.assertEqual(3, result.shape[0])
2526

2627
@gpu_test
28+
@p100_exempt
2729
def test_gpu_computation(self):
2830
cuda = torch.device('cuda')
2931
a = torch.tensor([1., 2.], device=cuda)
@@ -33,6 +35,7 @@ def test_gpu_computation(self):
3335
self.assertEqual(torch.tensor([3.], device=cuda), result)
3436

3537
@gpu_test
38+
@p100_exempt
3639
def test_cuda_nn(self):
3740
# These throw if cuda is misconfigured
3841
tnn.GRUCell(10,10).cuda()

tests/test_pytorch_lightning.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import torch.nn.functional as F
66
from torch.utils.data import DataLoader, TensorDataset
77

8+
from common import p100_exempt
9+
810

911
class LitDataModule(pl.LightningDataModule):
1012

@@ -59,6 +61,7 @@ class TestPytorchLightning(unittest.TestCase):
5961
def test_version(self):
6062
self.assertIsNotNone(pl.__version__)
6163

64+
@p100_exempt
6265
def test_mnist(self):
6366
dm = LitDataModule()
6467
model = LitClassifier()

0 commit comments

Comments
 (0)