|
| 1 | +from datetime import datetime |
1 | 2 | from typing import List, Literal, Optional |
2 | 3 |
|
3 | | -from pydantic import BaseModel, Field |
| 4 | +from pydantic import BaseModel, Field, field_validator |
4 | 5 | from pytest import FixtureRequest |
5 | 6 |
|
6 | 7 | from pipelex.core.stuff_content import ListContent, StructuredContent, TextContent |
@@ -88,6 +89,38 @@ class ComplexListContent(ListContent[PersonContent]): |
88 | 89 | items: List[PersonContent] |
89 | 90 |
|
90 | 91 |
|
| 92 | +class GanttTaskDetails(StructuredContent): |
| 93 | + """Do not include timezone in the dates.""" |
| 94 | + |
| 95 | + name: str |
| 96 | + start_date: Optional[datetime] = None |
| 97 | + end_date: Optional[datetime] = None |
| 98 | + |
| 99 | + @field_validator("start_date", "end_date") |
| 100 | + @classmethod |
| 101 | + def remove_tzinfo(cls, v: Optional[datetime]) -> Optional[datetime]: |
| 102 | + if v is not None: |
| 103 | + return v.replace(tzinfo=None) |
| 104 | + return v |
| 105 | + |
| 106 | + |
| 107 | +class Milestone(StructuredContent): |
| 108 | + name: str |
| 109 | + date: Optional[datetime] |
| 110 | + |
| 111 | + @field_validator("date") |
| 112 | + @classmethod |
| 113 | + def remove_tzinfo(cls, v: Optional[datetime]) -> Optional[datetime]: |
| 114 | + if v is not None: |
| 115 | + return v.replace(tzinfo=None) |
| 116 | + return v |
| 117 | + |
| 118 | + |
| 119 | +class GanttChart(StructuredContent): |
| 120 | + tasks: Optional[List[GanttTaskDetails]] = None |
| 121 | + milestones: Optional[List[Milestone]] = None |
| 122 | + |
| 123 | + |
91 | 124 | class TestTypeInspector: |
92 | 125 | """Tests for the type inspector functionality""" |
93 | 126 |
|
@@ -268,3 +301,24 @@ def test_literal_field_content(self, request: FixtureRequest): |
268 | 301 | ' WORLD = "world"', |
269 | 302 | ] |
270 | 303 | assert result == expected, f"Expected:\n{''.join(expected)}\n\nGot:\n{''.join(result)}" |
| 304 | + |
| 305 | + def test_gantt_chart_content(self, request: FixtureRequest): |
| 306 | + """Test structure of Gantt chart content with datetime validators""" |
| 307 | + |
| 308 | + result = get_type_structure(GanttChart, base_class=StructuredContent) |
| 309 | + expected = [ |
| 310 | + "class GanttChart(StructuredContent):", |
| 311 | + " tasks: Optional[List[GanttTaskDetails]] = None", |
| 312 | + " milestones: Optional[List[Milestone]] = None", |
| 313 | + "", |
| 314 | + "class GanttTaskDetails(StructuredContent):", |
| 315 | + ' """Do not include timezone in the dates."""', |
| 316 | + " name: str", |
| 317 | + " start_date: Optional[datetime] = None", |
| 318 | + " end_date: Optional[datetime] = None", |
| 319 | + "", |
| 320 | + "class Milestone(StructuredContent):", |
| 321 | + " name: str", |
| 322 | + " date: Optional[datetime] = None", |
| 323 | + ] |
| 324 | + assert result == expected, f"Expected:\n{''.join(expected)}\n\nGot:\n{''.join(result)}" |
0 commit comments