Cypherpunks.ru repositories - gocheese.git/blob - main.go
Unify copyright comment format
[gocheese.git] / main.go
1 // GoCheese -- Python private package repository and caching proxy
2 // Copyright (C) 2019-2024 Sergey Matveev <stargrave@stargrave.org>
3 //               2019-2024 Elena Balakhonova <balakhonova_e@riseup.net>
4 //
5 // This program is free software: you can redistribute it and/or modify
6 // it under the terms of the GNU General Public License as published by
7 // the Free Software Foundation, version 3 of the License.
8 //
9 // This program is distributed in the hope that it will be useful,
10 // but WITHOUT ANY WARRANTY; without even the implied warranty of
11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 // GNU General Public License for more details.
13 //
14 // You should have received a copy of the GNU General Public License
15 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
16
// GoCheese is a Python private package repository and caching proxy.
18 package main
19
20 import (
21         "bytes"
22         "context"
23         "crypto/sha256"
24         "crypto/tls"
25         "encoding/hex"
26         "errors"
27         "flag"
28         "fmt"
29         "log"
30         "net"
31         "net/http"
32         "net/url"
33         "os"
34         "os/signal"
35         "path/filepath"
36         "runtime"
37         "strings"
38         "syscall"
39         "time"
40
41         "golang.org/x/net/netutil"
42 )
43
// Version is the GoCheese release number printed by the -version flag.
const Version = "4.1.0"

// UserAgent is the product token of this program. This file sends it as the
// "Server" response header and prints it in the startup log line.
const UserAgent = "GoCheese/" + Version
48
49 var (
50         Root       string
51         Bind       = flag.String("bind", DefaultBind, "")
52         MaxClients = flag.Int("maxclients", DefaultMaxClients, "")
53         DoUCSPI    = flag.Bool("ucspi", false, "")
54
55         TLSCert = flag.String("tls-cert", "", "")
56         TLSKey  = flag.String("tls-key", "", "")
57
58         NoRefreshURLPath = flag.String("norefresh", DefaultNoRefreshURLPath, "")
59         RefreshURLPath   = flag.String("refresh", DefaultRefreshURLPath, "")
60         JSONURLPath      = flag.String("json", DefaultJSONURLPath, "")
61
62         PyPIURL      = flag.String("pypi", DefaultPyPIURL, "")
63         JSONURL      = flag.String("pypi-json", DefaultJSONURL, "")
64         PyPICertHash = flag.String("pypi-cert-hash", "", "")
65
66         PasswdPath     = flag.String("passwd", "", "")
67         PasswdListPath = flag.String("passwd-list", "", "")
68         PasswdCheck    = flag.Bool("passwd-check", false, "")
69
70         LogTimestamped = flag.Bool("log-timestamped", false, "")
71         FSCK           = flag.Bool("fsck", false, "")
72         DoVersion      = flag.Bool("version", false, "")
73         DoWarranty     = flag.Bool("warranty", false, "")
74
75         Killed bool
76 )
77
78 func servePkg(w http.ResponseWriter, r *http.Request, pkgName, filename string) {
79         log.Println(r.RemoteAddr, "get", filename)
80         path := filepath.Join(Root, pkgName, filename)
81         if _, err := os.Stat(path); os.IsNotExist(err) {
82                 if !refreshDir(w, r, pkgName, filename) {
83                         return
84                 }
85         }
86         http.ServeFile(w, r, path)
87 }
88
89 func handler(w http.ResponseWriter, r *http.Request) {
90         w.Header().Set("Server", UserAgent)
91         switch r.Method {
92         case "GET":
93                 var path string
94                 var autorefresh bool
95                 if strings.HasPrefix(r.URL.Path, *NoRefreshURLPath) {
96                         path = strings.TrimPrefix(r.URL.Path, *NoRefreshURLPath)
97                 } else if strings.HasPrefix(r.URL.Path, *RefreshURLPath) {
98                         path = strings.TrimPrefix(r.URL.Path, *RefreshURLPath)
99                         autorefresh = true
100                 } else {
101                         http.Error(w, "unknown action", http.StatusBadRequest)
102                         return
103                 }
104                 parts := strings.Split(strings.TrimSuffix(path, "/"), "/")
105                 if len(parts) > 2 {
106                         http.Error(w, "invalid path", http.StatusBadRequest)
107                         return
108                 }
109                 if len(parts) == 1 {
110                         if parts[0] == "" {
111                                 listRoot(w, r)
112                         } else {
113                                 serveListDir(w, r, parts[0], autorefresh)
114                         }
115                 } else {
116                         servePkg(w, r, parts[0], parts[1])
117                 }
118         case "POST":
119                 serveUpload(w, r)
120         default:
121                 http.Error(w, "unknown action", http.StatusBadRequest)
122         }
123 }
124
125 func main() {
126         flag.Usage = usage
127         flag.Parse()
128         if *DoWarranty {
129                 fmt.Println(Warranty)
130                 return
131         }
132         if *DoVersion {
133                 fmt.Println("GoCheese", Version, "built with", runtime.Version())
134                 return
135         }
136
137         if *LogTimestamped {
138                 log.SetFlags(log.Ldate | log.Lmicroseconds | log.Lshortfile)
139         } else {
140                 log.SetFlags(log.Lshortfile)
141         }
142         if !*DoUCSPI {
143                 log.SetOutput(os.Stdout)
144         }
145
146         if len(flag.Args()) != 1 {
147                 usage()
148                 os.Exit(1)
149         }
150         Root = flag.Args()[0]
151         if _, err := os.Stat(Root); err != nil {
152                 log.Fatal(err)
153         }
154
155         if *FSCK {
156                 if !goodIntegrity() {
157                         os.Exit(1)
158                 }
159                 return
160         }
161
162         if *PasswdCheck {
163                 if passwdReader(os.Stdin) {
164                         os.Exit(0)
165                 } else {
166                         os.Exit(1)
167                 }
168         }
169
170         if *PasswdPath != "" {
171                 go func() {
172                         for {
173                                 fd, err := os.OpenFile(
174                                         *PasswdPath,
175                                         os.O_RDONLY,
176                                         os.FileMode(0666),
177                                 )
178                                 if err != nil {
179                                         log.Fatal(err)
180                                 }
181                                 passwdReader(fd)
182                                 fd.Close()
183                         }
184                 }()
185         }
186         if *PasswdListPath != "" {
187                 go func() {
188                         for {
189                                 fd, err := os.OpenFile(
190                                         *PasswdListPath,
191                                         os.O_WRONLY|os.O_APPEND,
192                                         os.FileMode(0666),
193                                 )
194                                 if err != nil {
195                                         log.Fatal(err)
196                                 }
197                                 passwdLister(fd)
198                                 fd.Close()
199                         }
200                 }()
201         }
202
203         if (*TLSCert != "" && *TLSKey == "") || (*TLSCert == "" && *TLSKey != "") {
204                 log.Fatal("Both -tls-cert and -tls-key are required")
205         }
206
207         UmaskCur = syscall.Umask(0)
208         syscall.Umask(UmaskCur)
209
210         var err error
211         PyPIURLParsed, err = url.Parse(*PyPIURL)
212         if err != nil {
213                 log.Fatal(err)
214         }
215         tlsConfig := tls.Config{
216                 ClientSessionCache: tls.NewLRUClientSessionCache(16),
217                 NextProtos:         []string{"h2", "http/1.1"},
218         }
219         PyPIHTTPTransport = http.Transport{
220                 ForceAttemptHTTP2: true,
221                 TLSClientConfig:   &tlsConfig,
222         }
223         if *PyPICertHash != "" {
224                 ourDgst, err := hex.DecodeString(*PyPICertHash)
225                 if err != nil {
226                         log.Fatal(err)
227                 }
228                 tlsConfig.VerifyConnection = func(s tls.ConnectionState) error {
229                         spki := s.VerifiedChains[0][0].RawSubjectPublicKeyInfo
230                         theirDgst := sha256.Sum256(spki)
231                         if !bytes.Equal(ourDgst, theirDgst[:]) {
232                                 return errors.New("certificate's SPKI digest mismatch")
233                         }
234                         return nil
235                 }
236         }
237
238         server := &http.Server{
239                 ReadTimeout:  time.Minute,
240                 WriteTimeout: time.Minute,
241         }
242         http.HandleFunc("/", serveHRRoot)
243         http.HandleFunc("/hr/", serveHRPkg)
244         http.HandleFunc(*JSONURLPath, serveJSON)
245         http.HandleFunc(*NoRefreshURLPath, handler)
246         http.HandleFunc(*RefreshURLPath, handler)
247
248         if *DoUCSPI {
249                 server.SetKeepAlivesEnabled(false)
250                 ln := &UCSPI{}
251                 server.ConnState = connStater
252                 err := server.Serve(ln)
253                 if _, ok := err.(UCSPIAlreadyAccepted); !ok {
254                         log.Fatal(err)
255                 }
256                 UCSPIJob.Wait()
257                 return
258         }
259
260         ln, err := net.Listen("tcp", *Bind)
261         if err != nil {
262                 log.Fatal(err)
263         }
264         ln = netutil.LimitListener(ln, *MaxClients)
265
266         needsShutdown := make(chan os.Signal, 1)
267         exitErr := make(chan error)
268         signal.Notify(needsShutdown, syscall.SIGTERM, syscall.SIGINT)
269         go func(s *http.Server) {
270                 <-needsShutdown
271                 Killed = true
272                 log.Println("shutting down")
273                 ctx, cancel := context.WithTimeout(context.TODO(), time.Minute)
274                 exitErr <- s.Shutdown(ctx)
275                 cancel()
276         }(server)
277
278         log.Println(
279                 UserAgent, "ready:",
280                 "root:", Root,
281                 "bind:", *Bind,
282                 "pypi:", *PyPIURL,
283                 "json:", *JSONURL,
284         )
285         if *TLSCert == "" {
286                 err = server.Serve(ln)
287         } else {
288                 err = server.ServeTLS(ln, *TLSCert, *TLSKey)
289         }
290         if err != http.ErrServerClosed {
291                 log.Fatal(err)
292         }
293         if err := <-exitErr; err != nil {
294                 log.Fatal(err)
295         }
296 }