Skip to content

Commit 377a89c

Browse files
feature/make-from-str-default-native-text (Pipelex#78)
1 parent 805f3ee commit 377a89c

4 files changed

Lines changed: 63 additions & 2 deletions

File tree

pipelex/core/stuff_factory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,9 @@ def make_from_blueprint(cls, blueprint: StuffBlueprint) -> "Stuff":
9595
@classmethod
9696
def make_from_str(
9797
cls,
98-
concept_code: str,
9998
str_value: str,
10099
name: Optional[str] = None,
100+
concept_code: str = NativeConcept.TEXT.code,
101101
pipelex_session_id: Optional[str] = None,
102102
) -> Stuff:
103103
if not Concept.concept_str_contains_domain(concept_code):

pipelex/pipe_operators/pipe_llm_prompt.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ def get_output_structure_prompt(output_concept: str) -> str:
234234
"You do NOT need to output a formatted JSON object, another LLM will take care of that. "
235235
"If you cannot find a value that is Optional, output None for that field."
236236
"However, you MUST clearly output the values for each of these fields in your response.\n---\n"
237+
"DO NOT create information. If the information is not present, output None."
237238
)
238239
return output_structure_prompt
239240

pipelex/tools/typing/type_inspector.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,19 @@ def collect_types(tp: Type[Any]) -> None:
120120
for arg in non_none:
121121
if isinstance(arg, type):
122122
collect_types(arg)
123+
elif hasattr(arg, "__origin__"): # Handle nested generics
124+
collect_types(arg)
123125
elif origin in (list, List):
124126
if isinstance(args[0], type):
125127
collect_types(args[0])
128+
elif hasattr(args[0], "__origin__"): # Handle nested generics
129+
collect_types(args[0])
126130
elif origin in (dict, Dict):
127131
for arg in args:
128132
if isinstance(arg, type):
129133
collect_types(arg)
134+
elif hasattr(arg, "__origin__"): # Handle nested generics
135+
collect_types(arg)
130136
return
131137

132138
# Collect enums

tests/tools/typing/test_type_inspector.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from datetime import datetime
12
from typing import List, Literal, Optional
23

3-
from pydantic import BaseModel, Field
4+
from pydantic import BaseModel, Field, field_validator
45
from pytest import FixtureRequest
56

67
from pipelex.core.stuff_content import ListContent, StructuredContent, TextContent
@@ -88,6 +89,38 @@ class ComplexListContent(ListContent[PersonContent]):
8889
items: List[PersonContent]
8990

9091

92+
class GanttTaskDetails(StructuredContent):
93+
"""Do not include timezone in the dates."""
94+
95+
name: str
96+
start_date: Optional[datetime] = None
97+
end_date: Optional[datetime] = None
98+
99+
@field_validator("start_date", "end_date")
100+
@classmethod
101+
def remove_tzinfo(cls, v: Optional[datetime]) -> Optional[datetime]:
102+
if v is not None:
103+
return v.replace(tzinfo=None)
104+
return v
105+
106+
107+
class Milestone(StructuredContent):
108+
name: str
109+
date: Optional[datetime]
110+
111+
@field_validator("date")
112+
@classmethod
113+
def remove_tzinfo(cls, v: Optional[datetime]) -> Optional[datetime]:
114+
if v is not None:
115+
return v.replace(tzinfo=None)
116+
return v
117+
118+
119+
class GanttChart(StructuredContent):
120+
tasks: Optional[List[GanttTaskDetails]] = None
121+
milestones: Optional[List[Milestone]] = None
122+
123+
91124
class TestTypeInspector:
92125
"""Tests for the type inspector functionality"""
93126

@@ -268,3 +301,24 @@ def test_literal_field_content(self, request: FixtureRequest):
268301
' WORLD = "world"',
269302
]
270303
assert result == expected, f"Expected:\n{''.join(expected)}\n\nGot:\n{''.join(result)}"
304+
305+
def test_gantt_chart_content(self, request: FixtureRequest):
306+
"""Test structure of Gantt chart content with datetime validators"""
307+
308+
result = get_type_structure(GanttChart, base_class=StructuredContent)
309+
expected = [
310+
"class GanttChart(StructuredContent):",
311+
" tasks: Optional[List[GanttTaskDetails]] = None",
312+
" milestones: Optional[List[Milestone]] = None",
313+
"",
314+
"class GanttTaskDetails(StructuredContent):",
315+
' """Do not include timezone in the dates."""',
316+
" name: str",
317+
" start_date: Optional[datetime] = None",
318+
" end_date: Optional[datetime] = None",
319+
"",
320+
"class Milestone(StructuredContent):",
321+
" name: str",
322+
" date: Optional[datetime] = None",
323+
]
324+
assert result == expected, f"Expected:\n{''.join(expected)}\n\nGot:\n{''.join(result)}"

0 commit comments

Comments
 (0)