@@ -10,6 +10,7 @@ import (
1010 "errors"
1111 "fmt"
1212 "io"
13+ "log/slog"
1314 "net/http"
1415 "os"
1516 "path/filepath"
@@ -52,20 +53,26 @@ func cacheDir() (string, error) {
5253var flight singleflight.Group
5354
5455// Verify the provided sha256 is valid.
55- func (r * Runner ) parseChecksum ( ) (string , error ) {
56- if r .SHA256 = = "" {
57- return "" , fmt . Errorf ( "missing SHA-256 checksum" )
56+ func (r * Runner ) getChecksum ( ctx context. Context ) (string , error ) {
57+ if r .SHA256 ! = "" {
58+ return r . SHA256 , nil
5859 }
59- return r .SHA256 , nil
60+ // TODO: Add a log line here about something
61+ _ , sum , err := r .fetch (ctx , r .URL )
62+ if err != nil {
63+ return "" , err
64+ }
65+ slog .Warn ("fetching WASM binary to calculate sha256. Set this value in sqlc.yaml to prevent unneeded work" , "sha256" , sum )
66+ return sum , nil
6067}
6168
6269func (r * Runner ) loadModule (ctx context.Context , engine * wasmtime.Engine ) (* wasmtime.Module , error ) {
63- expected , err := r .parseChecksum ( )
70+ expected , err := r .getChecksum ( ctx )
6471 if err != nil {
6572 return nil , err
6673 }
6774 value , err , _ := flight .Do (expected , func () (interface {}, error ) {
68- return r .loadSerializedModule (ctx , engine )
75+ return r .loadSerializedModule (ctx , engine , expected )
6976 })
7077 if err != nil {
7178 return nil , err
@@ -77,17 +84,13 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm
7784 return wasmtime .NewModuleDeserialize (engine , data )
7885}
7986
80- func (r * Runner ) loadSerializedModule (ctx context.Context , engine * wasmtime.Engine ) ([]byte , error ) {
81- expected , err := r .parseChecksum ()
82- if err != nil {
83- return nil , err
84- }
87+ func (r * Runner ) loadSerializedModule (ctx context.Context , engine * wasmtime.Engine , expectedSha string ) ([]byte , error ) {
8588 cacheDir , err := cache .PluginsDir ()
8689 if err != nil {
8790 return nil , err
8891 }
8992
90- pluginDir := filepath .Join (cacheDir , expected )
93+ pluginDir := filepath .Join (cacheDir , expectedSha )
9194 modName := fmt .Sprintf ("plugin_%s_%s_%s.module" , runtime .GOOS , runtime .GOARCH , wasmtimeVersion )
9295 modPath := filepath .Join (pluginDir , modName )
9396 _ , staterr := os .Stat (modPath )
@@ -99,7 +102,7 @@ func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engi
99102 return data , nil
100103 }
101104
102- wmod , err := r .loadWASM (ctx , cacheDir , expected )
105+ wmod , err := r .loadWASM (ctx , cacheDir , expectedSha )
103106 if err != nil {
104107 return nil , err
105108 }
@@ -126,53 +129,62 @@ func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engi
126129 return out , nil
127130}
128131
129- func (r * Runner ) loadWASM (ctx context.Context , cache string , expected string ) ([]byte , error ) {
130- pluginDir := filepath .Join (cache , expected )
131- pluginPath := filepath .Join (pluginDir , "plugin.wasm" )
132- _ , staterr := os .Stat (pluginPath )
133-
132+ func (r * Runner ) fetch (ctx context.Context , uri string ) ([]byte , string , error ) {
134133 var body io.ReadCloser
134+
135135 switch {
136- case staterr == nil :
137- file , err := os .Open (pluginPath )
138- if err != nil {
139- return nil , fmt .Errorf ("os.Open: %s %w" , pluginPath , err )
140- }
141- body = file
142136
143- case strings .HasPrefix (r . URL , "file://" ):
144- file , err := os .Open (strings .TrimPrefix (r . URL , "file://" ))
137+ case strings .HasPrefix (uri , "file://" ):
138+ file , err := os .Open (strings .TrimPrefix (uri , "file://" ))
145139 if err != nil {
146- return nil , fmt .Errorf ("os.Open: %s %w" , r . URL , err )
140+ return nil , "" , fmt .Errorf ("os.Open: %s %w" , uri , err )
147141 }
148142 body = file
149143
150- case strings .HasPrefix (r . URL , "https://" ):
151- req , err := http .NewRequestWithContext (ctx , "GET" , r . URL , nil )
144+ case strings .HasPrefix (uri , "https://" ):
145+ req , err := http .NewRequestWithContext (ctx , "GET" , uri , nil )
152146 if err != nil {
153- return nil , fmt .Errorf ("http.Get: %s %w" , r . URL , err )
147+ return nil , "" , fmt .Errorf ("http.Get: %s %w" , uri , err )
154148 }
155149 req .Header .Set ("User-Agent" , fmt .Sprintf ("sqlc/%s Go/%s (%s %s)" , info .Version , runtime .Version (), runtime .GOOS , runtime .GOARCH ))
156150 resp , err := http .DefaultClient .Do (req )
157151 if err != nil {
158- return nil , fmt .Errorf ("http.Get: %s %w" , r .URL , err )
152+ return nil , "" , fmt .Errorf ("http.Get: %s %w" , r .URL , err )
159153 }
160154 body = resp .Body
161155
162156 default :
163- return nil , fmt .Errorf ("unknown scheme: %s" , r .URL )
157+ return nil , "" , fmt .Errorf ("unknown scheme: %s" , r .URL )
164158 }
165159
166160 defer body .Close ()
167161
168162 wmod , err := io .ReadAll (body )
169163 if err != nil {
170- return nil , fmt .Errorf ("readall: %w" , err )
164+ return nil , "" , fmt .Errorf ("readall: %w" , err )
171165 }
172166
173167 sum := sha256 .Sum256 (wmod )
174168 actual := fmt .Sprintf ("%x" , sum )
175169
170+ return wmod , actual , nil
171+ }
172+
173+ func (r * Runner ) loadWASM (ctx context.Context , cache string , expected string ) ([]byte , error ) {
174+ pluginDir := filepath .Join (cache , expected )
175+ pluginPath := filepath .Join (pluginDir , "plugin.wasm" )
176+ _ , staterr := os .Stat (pluginPath )
177+
178+ uri := r .URL
179+ if staterr == nil {
180+ uri = "file://" + pluginPath
181+ }
182+
183+ wmod , actual , err := r .fetch (ctx , uri )
184+ if err != nil {
185+ return nil , err
186+ }
187+
176188 if expected != actual {
177189 return nil , fmt .Errorf ("invalid checksum: expected %s, got %s" , expected , actual )
178190 }
0 commit comments