@@ -14,11 +14,11 @@ import (
1414
1515// OutputColumns determines which columns a statement will output
1616func (c * Compiler ) OutputColumns (stmt ast.Node ) ([]* catalog.Column , error ) {
17- qc , err := buildQueryCatalog (c .catalog , stmt , nil )
17+ qc , err := c . buildQueryCatalog (c .catalog , stmt , nil )
1818 if err != nil {
1919 return nil , err
2020 }
21- cols , err := outputColumns (qc , stmt )
21+ cols , err := c . outputColumns (qc , stmt )
2222 if err != nil {
2323 return nil , err
2424 }
@@ -51,8 +51,8 @@ func hasStarRef(cf *ast.ColumnRef) bool {
5151//
5252// Return an error if column references are ambiguous
5353// Return an error if column references don't exist
54- func outputColumns (qc * QueryCatalog , node ast.Node ) ([]* Column , error ) {
55- tables , err := sourceTables (qc , node )
54+ func ( c * Compiler ) outputColumns (qc * QueryCatalog , node ast.Node ) ([]* Column , error ) {
55+ tables , err := c . sourceTables (qc , node )
5656 if err != nil {
5757 return nil , err
5858 }
@@ -68,21 +68,50 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
6868
6969 if n .GroupClause != nil {
7070 for _ , item := range n .GroupClause .Items {
71- ref , ok := item .(* ast.ColumnRef )
72- if ! ok {
73- continue
74- }
75-
76- if err := findColumnForRef (ref , tables , n ); err != nil {
71+ if err := findColumnForNode (item , tables , n ); err != nil {
7772 return nil , err
7873 }
7974 }
8075 }
76+ validateOrderBy := true
77+ if c .conf .StrictOrderBy != nil {
78+ validateOrderBy = * c .conf .StrictOrderBy
79+ }
80+ if validateOrderBy {
81+ if n .SortClause != nil {
82+ for _ , item := range n .SortClause .Items {
83+ sb , ok := item .(* ast.SortBy )
84+ if ! ok {
85+ continue
86+ }
87+ if err := findColumnForNode (sb .Node , tables , n ); err != nil {
88+ return nil , fmt .Errorf ("%v: if you want to skip this validation, set 'strict_order_by' to false" , err )
89+ }
90+ }
91+ }
92+ if n .WindowClause != nil {
93+ for _ , item := range n .WindowClause .Items {
94+ sb , ok := item .(* ast.List )
95+ if ! ok {
96+ continue
97+ }
98+ for _ , single := range sb .Items {
99+ caseExpr , ok := single .(* ast.CaseExpr )
100+ if ! ok {
101+ continue
102+ }
103+ if err := findColumnForNode (caseExpr .Xpr , tables , n ); err != nil {
104+ return nil , fmt .Errorf ("%v: if you want to skip this validation, set 'strict_order_by' to false" , err )
105+ }
106+ }
107+ }
108+ }
109+ }
81110
82111 // For UNION queries, targets is empty and we need to look for the
83112 // columns in Largs.
84113 if len (targets .Items ) == 0 && n .Larg != nil {
85- return outputColumns (qc , n .Larg )
114+ return c . outputColumns (qc , n .Larg )
86115 }
87116 case * ast.CallStmt :
88117 targets = & ast.List {}
@@ -303,7 +332,7 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
303332 case ast .EXISTS_SUBLINK :
304333 cols = append (cols , & Column {Name : name , DataType : "bool" , NotNull : true })
305334 case ast .EXPR_SUBLINK :
306- subcols , err := outputColumns (qc , n .Subselect )
335+ subcols , err := c . outputColumns (qc , n .Subselect )
307336 if err != nil {
308337 return nil , err
309338 }
@@ -339,7 +368,7 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
339368 cols = append (cols , col )
340369
341370 case * ast.SelectStmt :
342- subcols , err := outputColumns (qc , n )
371+ subcols , err := c . outputColumns (qc , n )
343372 if err != nil {
344373 return nil , err
345374 }
@@ -428,7 +457,7 @@ func isTableRequired(n ast.Node, col *Column, prior int) int {
428457// Return an error if column references don't exist
429458// Return an error if a table is referenced twice
430459// Return an error if an unknown column is referenced
431- func sourceTables (qc * QueryCatalog , node ast.Node ) ([]* Table , error ) {
460+ func ( c * Compiler ) sourceTables (qc * QueryCatalog , node ast.Node ) ([]* Table , error ) {
432461 var list * ast.List
433462 switch n := node .(type ) {
434463 case * ast.DeleteStmt :
@@ -483,7 +512,7 @@ func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) {
483512 tables = append (tables , table )
484513
485514 case * ast.RangeSubselect :
486- cols , err := outputColumns (qc , n .Subquery )
515+ cols , err := c . outputColumns (qc , n .Subquery )
487516 if err != nil {
488517 return nil , err
489518 }
@@ -581,6 +610,14 @@ func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef)
581610 return cols , nil
582611}
583612
613+ func findColumnForNode (item ast.Node , tables []* Table , n * ast.SelectStmt ) error {
614+ ref , ok := item .(* ast.ColumnRef )
615+ if ! ok {
616+ return nil
617+ }
618+ return findColumnForRef (ref , tables , n )
619+ }
620+
584621func findColumnForRef (ref * ast.ColumnRef , tables []* Table , selectStatement * ast.SelectStmt ) error {
585622 parts := stringSlice (ref .Fields )
586623 var alias , name string
0 commit comments