]> Cypherpunks.ru repositories - gocheese.git/blob - main.go
More convenient trusted-host
[gocheese.git] / main.go
1 // GoCheese -- Python private package repository and caching proxy
2 // Copyright (C) 2019-2024 Sergey Matveev <stargrave@stargrave.org>
3 //               2019-2024 Elena Balakhonova <balakhonova_e@riseup.net>
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU General Public License as published by
7 // the Free Software Foundation, version 3 of the License.
8 //
9 // This program is distributed in the hope that it will be useful,
10 // but WITHOUT ANY WARRANTY; without even the implied warranty of
11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 // GNU General Public License for more details.
13 //
14 // You should have received a copy of the GNU General Public License
15 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
16
17 // Python private package repository and caching proxy
18 package main
19
20 import (
21         "bytes"
22         "context"
23         "crypto/sha256"
24         "crypto/tls"
25         "encoding/hex"
26         "errors"
27         "flag"
28         "fmt"
29         "log"
30         "net"
31         "net/http"
32         "net/url"
33         "os"
34         "os/signal"
35         "path/filepath"
36         "runtime"
37         "strings"
38         "syscall"
39         "time"
40
41         "golang.org/x/net/netutil"
42 )
43
// Version is the GoCheese release version.
const Version = "4.2.0"

// UserAgent is the program identifier sent in the Server response header
// and printed in the startup log line.
const UserAgent = "GoCheese/" + Version
48
49 var (
50         Root       string
51         Bind       = flag.String("bind", DefaultBind, "")
52         MaxClients = flag.Int("maxclients", DefaultMaxClients, "")
53         DoUCSPI    = flag.Bool("ucspi", false, "")
54
55         TLSCert = flag.String("tls-cert", "", "")
56         TLSKey  = flag.String("tls-key", "", "")
57
58         NoRefreshURLPath = flag.String("norefresh", DefaultNoRefreshURLPath, "")
59         RefreshURLPath   = flag.String("refresh", DefaultRefreshURLPath, "")
60         JSONURLPath      = flag.String("json", DefaultJSONURLPath, "")
61
62         PyPIURL      = flag.String("pypi", DefaultPyPIURL, "")
63         JSONURL      = flag.String("pypi-json", DefaultJSONURL, "")
64         PyPICertHash = flag.String("pypi-cert-hash", "", "")
65
66         PasswdPath     = flag.String("passwd", "", "")
67         PasswdListPath = flag.String("passwd-list", "", "")
68         PasswdCheck    = flag.Bool("passwd-check", false, "")
69         AuthRequired   = flag.Bool("auth-required", false, "")
70
71         LogTimestamped = flag.Bool("log-timestamped", false, "")
72         FSCK           = flag.Bool("fsck", false, "")
73         DoVersion      = flag.Bool("version", false, "")
74         DoWarranty     = flag.Bool("warranty", false, "")
75
76         Killed bool
77 )
78
79 func servePkg(w http.ResponseWriter, r *http.Request, pkgName, filename string) {
80         log.Println(r.RemoteAddr, "get", filename)
81         path := filepath.Join(Root, pkgName, filename)
82         if _, err := os.Stat(path); os.IsNotExist(err) {
83                 if !refreshDir(w, r, pkgName, filename) {
84                         return
85                 }
86         }
87         http.ServeFile(w, r, path)
88 }
89
90 func handler(w http.ResponseWriter, r *http.Request) {
91         w.Header().Set("Server", UserAgent)
92         switch r.Method {
93         case "GET":
94                 var path string
95                 var autorefresh bool
96                 if strings.HasPrefix(r.URL.Path, *NoRefreshURLPath) {
97                         path = strings.TrimPrefix(r.URL.Path, *NoRefreshURLPath)
98                 } else if strings.HasPrefix(r.URL.Path, *RefreshURLPath) {
99                         path = strings.TrimPrefix(r.URL.Path, *RefreshURLPath)
100                         autorefresh = true
101                 } else {
102                         http.Error(w, "unknown action", http.StatusBadRequest)
103                         return
104                 }
105                 parts := strings.Split(strings.TrimSuffix(path, "/"), "/")
106                 if len(parts) > 2 {
107                         http.Error(w, "invalid path", http.StatusBadRequest)
108                         return
109                 }
110                 if len(parts) == 1 {
111                         if parts[0] == "" {
112                                 listRoot(w, r)
113                         } else {
114                                 serveListDir(w, r, parts[0], autorefresh)
115                         }
116                 } else {
117                         servePkg(w, r, parts[0], parts[1])
118                 }
119         case "POST":
120                 serveUpload(w, r)
121         default:
122                 http.Error(w, "unknown action", http.StatusBadRequest)
123         }
124 }
125
126 func main() {
127         flag.Usage = usage
128         flag.Parse()
129         if *DoWarranty {
130                 fmt.Println(Warranty)
131                 return
132         }
133         if *DoVersion {
134                 fmt.Println("GoCheese", Version, "built with", runtime.Version())
135                 return
136         }
137
138         if *LogTimestamped {
139                 log.SetFlags(log.Ldate | log.Lmicroseconds | log.Lshortfile)
140         } else {
141                 log.SetFlags(log.Lshortfile)
142         }
143         if !*DoUCSPI {
144                 log.SetOutput(os.Stdout)
145         }
146
147         if len(flag.Args()) != 1 {
148                 usage()
149                 os.Exit(1)
150         }
151         Root = flag.Args()[0]
152         if _, err := os.Stat(Root); err != nil {
153                 log.Fatal(err)
154         }
155
156         if *FSCK {
157                 if !goodIntegrity() {
158                         os.Exit(1)
159                 }
160                 return
161         }
162
163         if *PasswdCheck {
164                 if passwdReader(os.Stdin) {
165                         os.Exit(0)
166                 } else {
167                         os.Exit(1)
168                 }
169         }
170
171         if *PasswdPath != "" {
172                 go func() {
173                         for {
174                                 fd, err := os.OpenFile(
175                                         *PasswdPath,
176                                         os.O_RDONLY,
177                                         os.FileMode(0666),
178                                 )
179                                 if err != nil {
180                                         log.Fatal(err)
181                                 }
182                                 passwdReader(fd)
183                                 fd.Close()
184                         }
185                 }()
186         }
187         if *PasswdListPath != "" {
188                 go func() {
189                         for {
190                                 fd, err := os.OpenFile(
191                                         *PasswdListPath,
192                                         os.O_WRONLY|os.O_APPEND,
193                                         os.FileMode(0666),
194                                 )
195                                 if err != nil {
196                                         log.Fatal(err)
197                                 }
198                                 passwdLister(fd)
199                                 fd.Close()
200                         }
201                 }()
202         }
203
204         if (*TLSCert != "" && *TLSKey == "") || (*TLSCert == "" && *TLSKey != "") {
205                 log.Fatal("Both -tls-cert and -tls-key are required")
206         }
207
208         UmaskCur = syscall.Umask(0)
209         syscall.Umask(UmaskCur)
210
211         var err error
212         PyPIURLParsed, err = url.Parse(*PyPIURL)
213         if err != nil {
214                 log.Fatal(err)
215         }
216         tlsConfig := tls.Config{
217                 ClientSessionCache: tls.NewLRUClientSessionCache(16),
218                 NextProtos:         []string{"h2", "http/1.1"},
219         }
220         PyPIHTTPTransport = http.Transport{
221                 ForceAttemptHTTP2: true,
222                 TLSClientConfig:   &tlsConfig,
223         }
224         if *PyPICertHash != "" {
225                 ourDgst, err := hex.DecodeString(*PyPICertHash)
226                 if err != nil {
227                         log.Fatal(err)
228                 }
229                 tlsConfig.VerifyConnection = func(s tls.ConnectionState) error {
230                         spki := s.VerifiedChains[0][0].RawSubjectPublicKeyInfo
231                         theirDgst := sha256.Sum256(spki)
232                         if !bytes.Equal(ourDgst, theirDgst[:]) {
233                                 return errors.New("certificate's SPKI digest mismatch")
234                         }
235                         return nil
236                 }
237         }
238
239         server := &http.Server{
240                 ReadTimeout:  time.Minute,
241                 WriteTimeout: time.Minute,
242         }
243         http.HandleFunc("/", checkAuth(serveHRRoot))
244         http.HandleFunc("/hr/", checkAuth(serveHRPkg))
245         http.HandleFunc(*JSONURLPath, checkAuth(serveJSON))
246         http.HandleFunc(*NoRefreshURLPath, checkAuth(handler))
247         http.HandleFunc(*RefreshURLPath, checkAuth(handler))
248
249         if *DoUCSPI {
250                 server.SetKeepAlivesEnabled(false)
251                 ln := &UCSPI{}
252                 server.ConnState = connStater
253                 err := server.Serve(ln)
254                 if _, ok := err.(UCSPIAlreadyAccepted); !ok {
255                         log.Fatal(err)
256                 }
257                 UCSPIJob.Wait()
258                 return
259         }
260
261         ln, err := net.Listen("tcp", *Bind)
262         if err != nil {
263                 log.Fatal(err)
264         }
265         ln = netutil.LimitListener(ln, *MaxClients)
266
267         needsShutdown := make(chan os.Signal, 1)
268         exitErr := make(chan error)
269         signal.Notify(needsShutdown, syscall.SIGTERM, syscall.SIGINT)
270         go func(s *http.Server) {
271                 <-needsShutdown
272                 Killed = true
273                 log.Println("shutting down")
274                 ctx, cancel := context.WithTimeout(context.TODO(), time.Minute)
275                 exitErr <- s.Shutdown(ctx)
276                 cancel()
277         }(server)
278
279         log.Println(
280                 UserAgent, "ready:",
281                 "root:", Root,
282                 "bind:", *Bind,
283                 "pypi:", *PyPIURL,
284                 "json:", *JSONURL,
285         )
286         if *TLSCert == "" {
287                 err = server.Serve(ln)
288         } else {
289                 err = server.ServeTLS(ln, *TLSCert, *TLSKey)
290         }
291         if err != http.ErrServerClosed {
292                 log.Fatal(err)
293         }
294         if err := <-exitErr; err != nil {
295                 log.Fatal(err)
296         }
297 }