-
Notifications
You must be signed in to change notification settings - Fork 645
Expand file tree
/
Copy path test_utils.py
More file actions
37 lines (30 loc) · 1.26 KB
/
test_utils.py
File metadata and controls
37 lines (30 loc) · 1.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from unittest import TestCase
from pyspark.sql import SparkSession
from lib.utils import load_survey_df, count_by_country
class UtilsTestCase(TestCase):
    """Tests for ``load_survey_df`` and ``count_by_country`` from ``lib.utils``.

    One SparkSession is created in ``setUpClass`` and shared by every test in
    the class; it is stopped once in ``tearDownClass``.
    """

    # Shared SparkSession; populated by setUpClass, None until then.
    spark = None

    @classmethod
    def setUpClass(cls) -> None:
        """Start (or reuse) a local SparkSession for the whole test class.

        ``local[3]`` runs Spark in-process with 3 worker threads.
        """
        cls.spark = (
            SparkSession.builder
            .master("local[3]")
            .appName("HelloSparkTest")
            .getOrCreate()
        )

    @classmethod
    def tearDownClass(cls) -> None:
        """Stop the shared SparkSession if one was started."""
        # Guard against the class-level default so teardown never fails with
        # AttributeError on None.
        if cls.spark is not None:
            cls.spark.stop()

    def test_datafile_loading(self):
        """``load_survey_df`` should load 9 records from the sample CSV."""
        # NOTE(review): relative path — assumes tests run from the project
        # root where data/sample.csv lives; confirm with the test runner setup.
        sample_df = load_survey_df(self.spark, "data/sample.csv")
        result_count = sample_df.count()
        self.assertEqual(result_count, 9, "Record count should be 9")

    def test_country_count(self):
        """``count_by_country`` should report the expected per-country totals."""
        sample_df = load_survey_df(self.spark, "data/sample.csv")
        # Each collected row exposes a "Country" and a "count" field; index
        # them by country so the assertions below can look counts up by name.
        count_dict = {
            row["Country"]: row["count"]
            for row in count_by_country(sample_df).collect()
        }
        self.assertEqual(count_dict["United States"], 4,
                         "Count for United States should be 4")
        self.assertEqual(count_dict["Canada"], 2,
                         "Count for Canada should be 2")
        self.assertEqual(count_dict["United Kingdom"], 1,
                         "Count for United Kingdom should be 1")