Cypherpunks.ru repositories - gocheese.git/blob - main.go
More convenient trusted-host
[gocheese.git] / main.go
1 /*
2 GoCheese -- Python private package repository and caching proxy
3 Copyright (C) 2019-2022 Sergey Matveev <stargrave@stargrave.org>
4               2019-2022 Elena Balakhonova <balakhonova_e@riseup.net>
5
6 This program is free software: you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation, version 3 of the License.
9
10 This program is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 // Python private package repository and caching proxy
20 package main
21
22 import (
23         "bytes"
24         "context"
25         "crypto/sha256"
26         "crypto/tls"
27         "encoding/hex"
28         "errors"
29         "flag"
30         "fmt"
31         "log"
32         "net"
33         "net/http"
34         "net/url"
35         "os"
36         "os/signal"
37         "path/filepath"
38         "runtime"
39         "strings"
40         "syscall"
41         "time"
42
43         "golang.org/x/net/netutil"
44 )
45
// Version is the GoCheese release number.
const Version = "3.5.0"

// UserAgent is "GoCheese/" followed by Version. It is sent as the Server
// response header and printed in the startup log line.
const UserAgent = "GoCheese/" + Version
50
51 var (
52         Root       string
53         Bind       = flag.String("bind", DefaultBind, "")
54         MaxClients = flag.Int("maxclients", DefaultMaxClients, "")
55         DoUCSPI    = flag.Bool("ucspi", false, "")
56
57         TLSCert = flag.String("tls-cert", "", "")
58         TLSKey  = flag.String("tls-key", "", "")
59
60         NoRefreshURLPath = flag.String("norefresh", DefaultNoRefreshURLPath, "")
61         RefreshURLPath   = flag.String("refresh", DefaultRefreshURLPath, "")
62         GPGUpdateURLPath = flag.String("gpgupdate", DefaultGPGUpdateURLPath, "")
63         JSONURLPath      = flag.String("json", DefaultJSONURLPath, "")
64
65         PyPIURL      = flag.String("pypi", DefaultPyPIURL, "")
66         JSONURL      = flag.String("pypi-json", DefaultJSONURL, "")
67         PyPICertHash = flag.String("pypi-cert-hash", "", "")
68
69         PasswdPath     = flag.String("passwd", "", "")
70         PasswdListPath = flag.String("passwd-list", "", "")
71         PasswdCheck    = flag.Bool("passwd-check", false, "")
72
73         LogTimestamped = flag.Bool("log-timestamped", false, "")
74         FSCK           = flag.Bool("fsck", false, "")
75         DoVersion      = flag.Bool("version", false, "")
76         DoWarranty     = flag.Bool("warranty", false, "")
77
78         Killed bool
79 )
80
81 func servePkg(w http.ResponseWriter, r *http.Request, pkgName, filename string) {
82         log.Println(r.RemoteAddr, "get", filename)
83         path := filepath.Join(Root, pkgName, filename)
84         if _, err := os.Stat(path); os.IsNotExist(err) {
85                 if !refreshDir(w, r, pkgName, filename, false) {
86                         return
87                 }
88         }
89         http.ServeFile(w, r, path)
90 }
91
92 func handler(w http.ResponseWriter, r *http.Request) {
93         w.Header().Set("Server", UserAgent)
94         switch r.Method {
95         case "GET":
96                 var path string
97                 var autorefresh bool
98                 var gpgUpdate bool
99                 if strings.HasPrefix(r.URL.Path, *NoRefreshURLPath) {
100                         path = strings.TrimPrefix(r.URL.Path, *NoRefreshURLPath)
101                 } else if strings.HasPrefix(r.URL.Path, *RefreshURLPath) {
102                         path = strings.TrimPrefix(r.URL.Path, *RefreshURLPath)
103                         autorefresh = true
104                 } else if strings.HasPrefix(r.URL.Path, *GPGUpdateURLPath) {
105                         path = strings.TrimPrefix(r.URL.Path, *GPGUpdateURLPath)
106                         autorefresh = true
107                         gpgUpdate = true
108                 } else {
109                         http.Error(w, "unknown action", http.StatusBadRequest)
110                         return
111                 }
112                 parts := strings.Split(strings.TrimSuffix(path, "/"), "/")
113                 if len(parts) > 2 {
114                         http.Error(w, "invalid path", http.StatusBadRequest)
115                         return
116                 }
117                 if len(parts) == 1 {
118                         if parts[0] == "" {
119                                 listRoot(w, r)
120                         } else {
121                                 serveListDir(w, r, parts[0], autorefresh, gpgUpdate)
122                         }
123                 } else {
124                         servePkg(w, r, parts[0], parts[1])
125                 }
126         case "POST":
127                 serveUpload(w, r)
128         default:
129                 http.Error(w, "unknown action", http.StatusBadRequest)
130         }
131 }
132
133 func main() {
134         flag.Usage = usage
135         flag.Parse()
136         if *DoWarranty {
137                 fmt.Println(Warranty)
138                 return
139         }
140         if *DoVersion {
141                 fmt.Println("GoCheese", Version, "built with", runtime.Version())
142                 return
143         }
144
145         if *LogTimestamped {
146                 log.SetFlags(log.Ldate | log.Lmicroseconds | log.Lshortfile)
147         } else {
148                 log.SetFlags(log.Lshortfile)
149         }
150         if !*DoUCSPI {
151                 log.SetOutput(os.Stdout)
152         }
153
154         if len(flag.Args()) != 1 {
155                 usage()
156                 os.Exit(1)
157         }
158         Root = flag.Args()[0]
159         if _, err := os.Stat(Root); err != nil {
160                 log.Fatalln(err)
161         }
162
163         if *FSCK {
164                 if !goodIntegrity() {
165                         os.Exit(1)
166                 }
167                 return
168         }
169
170         if *PasswdCheck {
171                 if passwdReader(os.Stdin) {
172                         os.Exit(0)
173                 } else {
174                         os.Exit(1)
175                 }
176         }
177
178         if *PasswdPath != "" {
179                 go func() {
180                         for {
181                                 fd, err := os.OpenFile(
182                                         *PasswdPath,
183                                         os.O_RDONLY,
184                                         os.FileMode(0666),
185                                 )
186                                 if err != nil {
187                                         log.Fatalln(err)
188                                 }
189                                 passwdReader(fd)
190                                 fd.Close()
191                         }
192                 }()
193         }
194         if *PasswdListPath != "" {
195                 go func() {
196                         for {
197                                 fd, err := os.OpenFile(
198                                         *PasswdListPath,
199                                         os.O_WRONLY|os.O_APPEND,
200                                         os.FileMode(0666),
201                                 )
202                                 if err != nil {
203                                         log.Fatalln(err)
204                                 }
205                                 passwdLister(fd)
206                                 fd.Close()
207                         }
208                 }()
209         }
210
211         if (*TLSCert != "" && *TLSKey == "") || (*TLSCert == "" && *TLSKey != "") {
212                 log.Fatalln("Both -tls-cert and -tls-key are required")
213         }
214
215         var err error
216         PyPIURLParsed, err = url.Parse(*PyPIURL)
217         if err != nil {
218                 log.Fatalln(err)
219         }
220         tlsConfig := tls.Config{
221                 ClientSessionCache: tls.NewLRUClientSessionCache(16),
222                 NextProtos:         []string{"h2", "http/1.1"},
223         }
224         PyPIHTTPTransport = http.Transport{
225                 ForceAttemptHTTP2: true,
226                 TLSClientConfig:   &tlsConfig,
227         }
228         if *PyPICertHash != "" {
229                 ourDgst, err := hex.DecodeString(*PyPICertHash)
230                 if err != nil {
231                         log.Fatalln(err)
232                 }
233                 tlsConfig.VerifyConnection = func(s tls.ConnectionState) error {
234                         spki := s.VerifiedChains[0][0].RawSubjectPublicKeyInfo
235                         theirDgst := sha256.Sum256(spki)
236                         if bytes.Compare(ourDgst, theirDgst[:]) != 0 {
237                                 return errors.New("certificate's SPKI digest mismatch")
238                         }
239                         return nil
240                 }
241         }
242
243         server := &http.Server{
244                 ReadTimeout:  time.Minute,
245                 WriteTimeout: time.Minute,
246         }
247         http.HandleFunc("/", serveHRRoot)
248         http.HandleFunc("/hr/", serveHRPkg)
249         http.HandleFunc(*JSONURLPath, serveJSON)
250         http.HandleFunc(*NoRefreshURLPath, handler)
251         http.HandleFunc(*RefreshURLPath, handler)
252         if *GPGUpdateURLPath != "" {
253                 http.HandleFunc(*GPGUpdateURLPath, handler)
254         }
255
256         if *DoUCSPI {
257                 server.SetKeepAlivesEnabled(false)
258                 ln := &UCSPI{}
259                 server.ConnState = connStater
260                 err := server.Serve(ln)
261                 if _, ok := err.(UCSPIAlreadyAccepted); !ok {
262                         log.Fatalln(err)
263                 }
264                 UCSPIJob.Wait()
265                 return
266         }
267
268         ln, err := net.Listen("tcp", *Bind)
269         if err != nil {
270                 log.Fatal(err)
271         }
272         ln = netutil.LimitListener(ln, *MaxClients)
273
274         needsShutdown := make(chan os.Signal, 0)
275         exitErr := make(chan error, 0)
276         signal.Notify(needsShutdown, syscall.SIGTERM, syscall.SIGINT)
277         go func(s *http.Server) {
278                 <-needsShutdown
279                 Killed = true
280                 log.Println("shutting down")
281                 ctx, cancel := context.WithTimeout(context.TODO(), time.Minute)
282                 exitErr <- s.Shutdown(ctx)
283                 cancel()
284         }(server)
285
286         log.Println(
287                 UserAgent, "ready:",
288                 "root:", Root,
289                 "bind:", *Bind,
290                 "pypi:", *PyPIURL,
291                 "json:", *JSONURL,
292         )
293         if *TLSCert == "" {
294                 err = server.Serve(ln)
295         } else {
296                 err = server.ServeTLS(ln, *TLSCert, *TLSKey)
297         }
298         if err != http.ErrServerClosed {
299                 log.Fatal(err)
300         }
301         if err := <-exitErr; err != nil {
302                 log.Fatal(err)
303         }
304 }