]> Cypherpunks.ru repositories - gocheese.git/blob - main.go
Fatal(err) is enough
[gocheese.git] / main.go
1 /*
2 GoCheese -- Python private package repository and caching proxy
3 Copyright (C) 2019-2023 Sergey Matveev <stargrave@stargrave.org>
4               2019-2023 Elena Balakhonova <balakhonova_e@riseup.net>
5
6 This program is free software: you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation, version 3 of the License.
9
10 This program is distributed in the hope that it will be useful,
11 but WITHOUT ANY WARRANTY; without even the implied warranty of
12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License
16 along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19 // Python private package repository and caching proxy
20 package main
21
22 import (
23         "bytes"
24         "context"
25         "crypto/sha256"
26         "crypto/tls"
27         "encoding/hex"
28         "errors"
29         "flag"
30         "fmt"
31         "log"
32         "net"
33         "net/http"
34         "net/url"
35         "os"
36         "os/signal"
37         "path/filepath"
38         "runtime"
39         "strings"
40         "syscall"
41         "time"
42
43         "golang.org/x/net/netutil"
44 )
45
// Version is the GoCheese release version; it is printed by the -version flag.
const Version = "4.0.0"

// UserAgent is "GoCheese/" followed by Version. The HTTP handler sends it
// as the Server response header.
const UserAgent = "GoCheese/" + Version
50
51 var (
52         Root       string
53         Bind       = flag.String("bind", DefaultBind, "")
54         MaxClients = flag.Int("maxclients", DefaultMaxClients, "")
55         DoUCSPI    = flag.Bool("ucspi", false, "")
56
57         TLSCert = flag.String("tls-cert", "", "")
58         TLSKey  = flag.String("tls-key", "", "")
59
60         NoRefreshURLPath = flag.String("norefresh", DefaultNoRefreshURLPath, "")
61         RefreshURLPath   = flag.String("refresh", DefaultRefreshURLPath, "")
62         JSONURLPath      = flag.String("json", DefaultJSONURLPath, "")
63
64         PyPIURL      = flag.String("pypi", DefaultPyPIURL, "")
65         JSONURL      = flag.String("pypi-json", DefaultJSONURL, "")
66         PyPICertHash = flag.String("pypi-cert-hash", "", "")
67
68         PasswdPath     = flag.String("passwd", "", "")
69         PasswdListPath = flag.String("passwd-list", "", "")
70         PasswdCheck    = flag.Bool("passwd-check", false, "")
71
72         LogTimestamped = flag.Bool("log-timestamped", false, "")
73         FSCK           = flag.Bool("fsck", false, "")
74         DoVersion      = flag.Bool("version", false, "")
75         DoWarranty     = flag.Bool("warranty", false, "")
76
77         Killed bool
78 )
79
80 func servePkg(w http.ResponseWriter, r *http.Request, pkgName, filename string) {
81         log.Println(r.RemoteAddr, "get", filename)
82         path := filepath.Join(Root, pkgName, filename)
83         if _, err := os.Stat(path); os.IsNotExist(err) {
84                 if !refreshDir(w, r, pkgName, filename) {
85                         return
86                 }
87         }
88         http.ServeFile(w, r, path)
89 }
90
91 func handler(w http.ResponseWriter, r *http.Request) {
92         w.Header().Set("Server", UserAgent)
93         switch r.Method {
94         case "GET":
95                 var path string
96                 var autorefresh bool
97                 if strings.HasPrefix(r.URL.Path, *NoRefreshURLPath) {
98                         path = strings.TrimPrefix(r.URL.Path, *NoRefreshURLPath)
99                 } else if strings.HasPrefix(r.URL.Path, *RefreshURLPath) {
100                         path = strings.TrimPrefix(r.URL.Path, *RefreshURLPath)
101                         autorefresh = true
102                 } else {
103                         http.Error(w, "unknown action", http.StatusBadRequest)
104                         return
105                 }
106                 parts := strings.Split(strings.TrimSuffix(path, "/"), "/")
107                 if len(parts) > 2 {
108                         http.Error(w, "invalid path", http.StatusBadRequest)
109                         return
110                 }
111                 if len(parts) == 1 {
112                         if parts[0] == "" {
113                                 listRoot(w, r)
114                         } else {
115                                 serveListDir(w, r, parts[0], autorefresh)
116                         }
117                 } else {
118                         servePkg(w, r, parts[0], parts[1])
119                 }
120         case "POST":
121                 serveUpload(w, r)
122         default:
123                 http.Error(w, "unknown action", http.StatusBadRequest)
124         }
125 }
126
127 func main() {
128         flag.Usage = usage
129         flag.Parse()
130         if *DoWarranty {
131                 fmt.Println(Warranty)
132                 return
133         }
134         if *DoVersion {
135                 fmt.Println("GoCheese", Version, "built with", runtime.Version())
136                 return
137         }
138
139         if *LogTimestamped {
140                 log.SetFlags(log.Ldate | log.Lmicroseconds | log.Lshortfile)
141         } else {
142                 log.SetFlags(log.Lshortfile)
143         }
144         if !*DoUCSPI {
145                 log.SetOutput(os.Stdout)
146         }
147
148         if len(flag.Args()) != 1 {
149                 usage()
150                 os.Exit(1)
151         }
152         Root = flag.Args()[0]
153         if _, err := os.Stat(Root); err != nil {
154                 log.Fatal(err)
155         }
156
157         if *FSCK {
158                 if !goodIntegrity() {
159                         os.Exit(1)
160                 }
161                 return
162         }
163
164         if *PasswdCheck {
165                 if passwdReader(os.Stdin) {
166                         os.Exit(0)
167                 } else {
168                         os.Exit(1)
169                 }
170         }
171
172         if *PasswdPath != "" {
173                 go func() {
174                         for {
175                                 fd, err := os.OpenFile(
176                                         *PasswdPath,
177                                         os.O_RDONLY,
178                                         os.FileMode(0666),
179                                 )
180                                 if err != nil {
181                                         log.Fatal(err)
182                                 }
183                                 passwdReader(fd)
184                                 fd.Close()
185                         }
186                 }()
187         }
188         if *PasswdListPath != "" {
189                 go func() {
190                         for {
191                                 fd, err := os.OpenFile(
192                                         *PasswdListPath,
193                                         os.O_WRONLY|os.O_APPEND,
194                                         os.FileMode(0666),
195                                 )
196                                 if err != nil {
197                                         log.Fatal(err)
198                                 }
199                                 passwdLister(fd)
200                                 fd.Close()
201                         }
202                 }()
203         }
204
205         if (*TLSCert != "" && *TLSKey == "") || (*TLSCert == "" && *TLSKey != "") {
206                 log.Fatal("Both -tls-cert and -tls-key are required")
207         }
208
209         UmaskCur = syscall.Umask(0)
210         syscall.Umask(UmaskCur)
211
212         var err error
213         PyPIURLParsed, err = url.Parse(*PyPIURL)
214         if err != nil {
215                 log.Fatal(err)
216         }
217         tlsConfig := tls.Config{
218                 ClientSessionCache: tls.NewLRUClientSessionCache(16),
219                 NextProtos:         []string{"h2", "http/1.1"},
220         }
221         PyPIHTTPTransport = http.Transport{
222                 ForceAttemptHTTP2: true,
223                 TLSClientConfig:   &tlsConfig,
224         }
225         if *PyPICertHash != "" {
226                 ourDgst, err := hex.DecodeString(*PyPICertHash)
227                 if err != nil {
228                         log.Fatal(err)
229                 }
230                 tlsConfig.VerifyConnection = func(s tls.ConnectionState) error {
231                         spki := s.VerifiedChains[0][0].RawSubjectPublicKeyInfo
232                         theirDgst := sha256.Sum256(spki)
233                         if !bytes.Equal(ourDgst, theirDgst[:]) {
234                                 return errors.New("certificate's SPKI digest mismatch")
235                         }
236                         return nil
237                 }
238         }
239
240         server := &http.Server{
241                 ReadTimeout:  time.Minute,
242                 WriteTimeout: time.Minute,
243         }
244         http.HandleFunc("/", serveHRRoot)
245         http.HandleFunc("/hr/", serveHRPkg)
246         http.HandleFunc(*JSONURLPath, serveJSON)
247         http.HandleFunc(*NoRefreshURLPath, handler)
248         http.HandleFunc(*RefreshURLPath, handler)
249
250         if *DoUCSPI {
251                 server.SetKeepAlivesEnabled(false)
252                 ln := &UCSPI{}
253                 server.ConnState = connStater
254                 err := server.Serve(ln)
255                 if _, ok := err.(UCSPIAlreadyAccepted); !ok {
256                         log.Fatal(err)
257                 }
258                 UCSPIJob.Wait()
259                 return
260         }
261
262         ln, err := net.Listen("tcp", *Bind)
263         if err != nil {
264                 log.Fatal(err)
265         }
266         ln = netutil.LimitListener(ln, *MaxClients)
267
268         needsShutdown := make(chan os.Signal, 1)
269         exitErr := make(chan error)
270         signal.Notify(needsShutdown, syscall.SIGTERM, syscall.SIGINT)
271         go func(s *http.Server) {
272                 <-needsShutdown
273                 Killed = true
274                 log.Println("shutting down")
275                 ctx, cancel := context.WithTimeout(context.TODO(), time.Minute)
276                 exitErr <- s.Shutdown(ctx)
277                 cancel()
278         }(server)
279
280         log.Println(
281                 UserAgent, "ready:",
282                 "root:", Root,
283                 "bind:", *Bind,
284                 "pypi:", *PyPIURL,
285                 "json:", *JSONURL,
286         )
287         if *TLSCert == "" {
288                 err = server.Serve(ln)
289         } else {
290                 err = server.ServeTLS(ln, *TLSCert, *TLSKey)
291         }
292         if err != http.ErrServerClosed {
293                 log.Fatal(err)
294         }
295         if err := <-exitErr; err != nil {
296                 log.Fatal(err)
297         }
298 }