@@ -585,6 +585,26 @@ func dataclassNode(name string) *pyast.ClassDef {
585585 }
586586}
587587
588+ func pydanticNode (name string ) * pyast.ClassDef {
589+ return & pyast.ClassDef {
590+ Name : name ,
591+ Bases : []* pyast.Node {
592+ {
593+ Node : & pyast.Node_Attribute {
594+ Attribute : & pyast.Attribute {
595+ Value : & pyast.Node {
596+ Node : & pyast.Node_Name {
597+ Name : & pyast.Name {Id : "pydantic" },
598+ },
599+ },
600+ Attr : "BaseModel" ,
601+ },
602+ },
603+ },
604+ },
605+ }
606+ }
607+
588608func fieldNode (f Field ) * pyast.Node {
589609 return & pyast.Node {
590610 Node : & pyast.Node_AnnAssign {
@@ -692,7 +712,12 @@ func buildModelsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
692712 }
693713
694714 for _ , m := range ctx .Models {
695- def := dataclassNode (m .Name )
715+ var def * pyast.ClassDef
716+ if ctx .EmitPydanticModels {
717+ def = pydanticNode (m .Name )
718+ } else {
719+ def = dataclassNode (m .Name )
720+ }
696721 if m .Comment != "" {
697722 def .Body = append (def .Body , & pyast.Node {
698723 Node : & pyast.Node_Expr {
@@ -822,15 +847,25 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
822847 mod .Body = append (mod .Body , assignNode (q .ConstantName , poet .Constant (queryText )))
823848 for _ , arg := range q .Args {
824849 if arg .EmitStruct () {
825- def := dataclassNode (arg .Struct .Name )
850+ var def * pyast.ClassDef
851+ if ctx .EmitPydanticModels {
852+ def = pydanticNode (arg .Struct .Name )
853+ } else {
854+ def = dataclassNode (arg .Struct .Name )
855+ }
826856 for _ , f := range arg .Struct .Fields {
827857 def .Body = append (def .Body , fieldNode (f ))
828858 }
829859 mod .Body = append (mod .Body , poet .Node (def ))
830860 }
831861 }
832862 if q .Ret .EmitStruct () {
833- def := dataclassNode (q .Ret .Struct .Name )
863+ var def * pyast.ClassDef
864+ if ctx .EmitPydanticModels {
865+ def = pydanticNode (q .Ret .Struct .Name )
866+ } else {
867+ def = dataclassNode (q .Ret .Struct .Name )
868+ }
834869 for _ , f := range q .Ret .Struct .Fields {
835870 def .Body = append (def .Body , fieldNode (f ))
836871 }
@@ -1027,13 +1062,14 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
10271062}
10281063
10291064type pyTmplCtx struct {
1030- Models []Struct
1031- Queries []Query
1032- Enums []Enum
1033- EmitSync bool
1034- EmitAsync bool
1035- SourceName string
1036- SqlcVersion string
1065+ Models []Struct
1066+ Queries []Query
1067+ Enums []Enum
1068+ EmitSync bool
1069+ EmitAsync bool
1070+ SourceName string
1071+ SqlcVersion string
1072+ EmitPydanticModels bool
10371073}
10381074
10391075func (t * pyTmplCtx ) OutputQuery (sourceName string ) bool {
@@ -1060,12 +1096,13 @@ func Generate(req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
10601096 }
10611097
10621098 tctx := pyTmplCtx {
1063- Models : models ,
1064- Queries : queries ,
1065- Enums : enums ,
1066- EmitSync : req .Settings .Python .EmitSyncQuerier ,
1067- EmitAsync : req .Settings .Python .EmitAsyncQuerier ,
1068- SqlcVersion : req .SqlcVersion ,
1099+ Models : models ,
1100+ Queries : queries ,
1101+ Enums : enums ,
1102+ EmitSync : req .Settings .Python .EmitSyncQuerier ,
1103+ EmitAsync : req .Settings .Python .EmitAsyncQuerier ,
1104+ SqlcVersion : req .SqlcVersion ,
1105+ EmitPydanticModels : req .Settings .Python .EmitPydanticModels ,
10691106 }
10701107
10711108 output := map [string ]string {}
0 commit comments