]> Cypherpunks.ru repositories - gocheese.git/blobdiff - gocheese.go
Mention about graceful shutdown different from cheeseshop
[gocheese.git] / gocheese.go
index 7e5f79ec0e86fd0ad599035050aa8cf2cdb993a5..7a6d9ffc85ab9e8d54f99386e30e3b2e6304e743 100644 (file)
@@ -4,8 +4,7 @@ Copyright (C) 2019 Sergey Matveev <stargrave@stargrave.org>
 
 This program is free software: you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
-the Free Software Foundation, either version 3 of the License, or
-(at your option) any later version.
+the Free Software Foundation, version 3 of the License.
 
 This program is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
@@ -21,6 +20,7 @@ package main
 
 import (
        "bytes"
+       "context"
        "crypto/sha256"
        "encoding/hex"
        "flag"
@@ -31,10 +31,13 @@ import (
        "net/http"
        "net/url"
        "os"
+       "os/signal"
        "path/filepath"
        "regexp"
        "runtime"
        "strings"
+       "syscall"
+       "time"
 )
 
 const (
@@ -47,8 +50,7 @@ const (
 
        Warranty = `This program is free software: you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
-the Free Software Foundation, either version 3 of the License, or
-(at your option) any later version.
+the Free Software Foundation, version 3 of the License.
 
 This program is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
@@ -60,26 +62,31 @@ along with this program.  If not, see <http://www.gnu.org/licenses/>.`
 )
 
 var (
-       root           = flag.String("root", "./packages", "Path to packages directory")
-       bind           = flag.String("bind", "[::]:8080", "Address to bind to")
-       simpleURLPath  = flag.String("simple", "/simple/", "/simple/ URL path")
-       refreshURLPath = flag.String("refresh", "/refresh/", "Auto-refreshing URL path")
-       pypiURL        = flag.String("pypi", "https://pypi.org/simple/", "Upstream PyPI URL")
-       auth           = flag.String("auth", "spam:foo", "login:password,...")
-       fsck           = flag.Bool("fsck", false, "Check integrity of all packages")
-       version        = flag.Bool("version", false, "Print version information")
-       warranty       = flag.Bool("warranty", false, "Print warranty information")
+       root             = flag.String("root", "./packages", "Path to packages directory")
+       bind             = flag.String("bind", "[::]:8080", "Address to bind to")
+       norefreshURLPath = flag.String("norefresh", "/norefresh/", "Non-refreshing URL path")
+       refreshURLPath   = flag.String("refresh", "/simple/", "Auto-refreshing URL path")
+       pypiURL          = flag.String("pypi", "https://pypi.org/simple/", "Upstream PyPI URL")
+       passwdPath       = flag.String("passwd", "passwd", "Path to file with authenticators")
+       passwdCheck      = flag.Bool("passwd-check", false, "Test the -passwd file for syntax errors and exit")
+       fsck             = flag.Bool("fsck", false, "Check integrity of all packages")
+       version          = flag.Bool("version", false, "Print version information")
+       warranty         = flag.Bool("warranty", false, "Print warranty information")
 
        pkgPyPI        = regexp.MustCompile(`^.*<a href="([^"]+)"[^>]*>(.+)</a><br/>.*$`)
        Version string = "UNKNOWN"
 
-       passwords map[string]string = make(map[string]string)
+       passwords map[string]Auther = make(map[string]Auther)
 )
 
+type Auther interface {
+       Auth(password string) bool
+}
+
 func mkdirForPkg(w http.ResponseWriter, r *http.Request, dir string) bool {
        path := filepath.Join(*root, dir)
        if _, err := os.Stat(path); os.IsNotExist(err) {
-               if err = os.Mkdir(path, 0700); err != nil {
+               if err = os.Mkdir(path, os.FileMode(0777)); err != nil {
                        http.Error(w, err.Error(), http.StatusInternalServerError)
                        return false
                }
@@ -182,7 +189,7 @@ func refreshDir(w http.ResponseWriter, r *http.Request, dir, filenameGet string)
                        }
                }
                log.Println(r.RemoteAddr, "pypi touch", filename)
-               if err = ioutil.WriteFile(path, digest, os.FileMode(0600)); err != nil {
+               if err = ioutil.WriteFile(path, digest, os.FileMode(0666)); err != nil {
                        http.Error(w, err.Error(), http.StatusInternalServerError)
                        return false
                }
@@ -202,7 +209,7 @@ func listRoot(w http.ResponseWriter, r *http.Request) {
                if file.Mode().IsDir() {
                        w.Write([]byte(fmt.Sprintf(
                                HTMLElement,
-                               *simpleURLPath+file.Name()+"/",
+                               *refreshURLPath+file.Name()+"/",
                                file.Name(),
                        )))
                }
@@ -241,7 +248,7 @@ func listDir(w http.ResponseWriter, r *http.Request, dir string, autorefresh boo
                w.Write([]byte(fmt.Sprintf(
                        HTMLElement,
                        strings.Join([]string{
-                               *simpleURLPath, dir, "/",
+                               *refreshURLPath, dir, "/",
                                filenameClean, "#", SHA256Prefix, string(data),
                        }, ""),
                        filenameClean,
@@ -263,7 +270,13 @@ func servePkg(w http.ResponseWriter, r *http.Request, dir, filename string) {
 
 func serveUpload(w http.ResponseWriter, r *http.Request) {
        username, password, ok := r.BasicAuth()
-       if !ok || passwords[username] != password {
+       if !ok {
+               log.Println(r.RemoteAddr, "unauthenticated", username)
+               http.Error(w, "unauthenticated", http.StatusUnauthorized)
+               return
+       }
+       auther, ok := passwords[username]
+       if !ok || !auther.Auth(password) {
                log.Println(r.RemoteAddr, "unauthenticated", username)
                http.Error(w, "unauthenticated", http.StatusUnauthorized)
                return
@@ -326,11 +339,7 @@ func serveUpload(w http.ResponseWriter, r *http.Request) {
                        http.Error(w, err.Error(), http.StatusInternalServerError)
                        return
                }
-               if err = ioutil.WriteFile(
-                       path+SHA256Ext,
-                       hasher.Sum(nil),
-                       os.FileMode(0600),
-               ); err != nil {
+               if err = ioutil.WriteFile(path+SHA256Ext, hasher.Sum(nil), os.FileMode(0666)); err != nil {
                        http.Error(w, err.Error(), http.StatusInternalServerError)
                        return
                }
@@ -341,8 +350,8 @@ func handler(w http.ResponseWriter, r *http.Request) {
        if r.Method == "GET" {
                var path string
                var autorefresh bool
-               if strings.HasPrefix(r.URL.Path, *simpleURLPath) {
-                       path = strings.TrimPrefix(r.URL.Path, *simpleURLPath)
+               if strings.HasPrefix(r.URL.Path, *norefreshURLPath) {
+                       path = strings.TrimPrefix(r.URL.Path, *norefreshURLPath)
                        autorefresh = false
                } else {
                        path = strings.TrimPrefix(r.URL.Path, *refreshURLPath)
@@ -427,15 +436,41 @@ func main() {
                }
                return
        }
-       for _, credentials := range strings.Split(*auth, ",") {
-               splitted := strings.Split(credentials, ":")
-               if len(splitted) != 2 {
-                       log.Fatal("Wrong auth format")
-               }
-               passwords[splitted[0]] = splitted[1]
+       if *passwdCheck {
+               refreshPasswd()
+               return
        }
+       refreshPasswd()
        log.Println("root:", *root, "bind:", *bind)
-       http.HandleFunc(*simpleURLPath, handler)
+       needsRefreshPasswd := make(chan os.Signal, 0)
+       needsShutdown := make(chan os.Signal, 0)
+       killed := make(chan error, 0)
+       http.HandleFunc(*norefreshURLPath, handler)
        http.HandleFunc(*refreshURLPath, handler)
-       log.Fatal(http.ListenAndServe(*bind, nil))
+       s := &http.Server{
+               Addr:           *bind,
+               ReadTimeout:    time.Minute,
+               WriteTimeout:   time.Minute,
+       }
+       signal.Notify(needsRefreshPasswd, syscall.SIGHUP)
+       signal.Notify(needsShutdown, syscall.SIGTERM, syscall.SIGINT)
+       go func() {
+               for range needsRefreshPasswd {
+                       log.Println("Refreshing passwords")
+                       refreshPasswd()
+               }
+       }()
+       go func(s *http.Server) {
+               <-needsShutdown
+               log.Println("Shutting down")
+               ctx, cancel := context.WithTimeout(context.TODO(), time.Minute)
+               killed <- s.Shutdown(ctx)
+               cancel()
+       }(s)
+       if err := s.ListenAndServe(); err != http.ErrServerClosed {
+               log.Fatal(err)
+       }
+       if err := <-killed; err != nil {
+               log.Fatal(err)
+       }
 }