+/*
+GoCheese -- Python private package repository and caching proxy
+Copyright (C) 2019 Sergey Matveev <stargrave@stargrave.org>
+
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, version 3 of the License.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package main
+
+import (
+ "bufio"
+ "bytes"
+ "crypto/md5"
+ "crypto/sha256"
+ "crypto/sha512"
+ "encoding/hex"
+ "hash"
+ "io"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "golang.org/x/crypto/blake2b"
+)
+
+func blake2b256New() hash.Hash {
+ h, err := blake2b.New256(nil)
+ if err != nil {
+ panic(err)
+ }
+ return h
+}
+
+func refreshDir(
+ w http.ResponseWriter,
+ r *http.Request,
+ dir,
+ filenameGet string,
+ gpgUpdate bool,
+) bool {
+ if _, err := os.Stat(filepath.Join(*root, dir, InternalFlag)); err == nil {
+ return true
+ }
+ resp, err := http.Get(*pypiURL + dir + "/")
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusBadGateway)
+ return false
+ }
+ body, err := ioutil.ReadAll(resp.Body)
+ resp.Body.Close()
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusBadGateway)
+ return false
+ }
+ if !mkdirForPkg(w, r, dir) {
+ return false
+ }
+ dirPath := filepath.Join(*root, dir)
+ for _, lineRaw := range bytes.Split(body, []byte("\n")) {
+ submatches := pkgPyPI.FindStringSubmatch(string(lineRaw))
+ if len(submatches) == 0 {
+ continue
+ }
+ uri := submatches[1]
+ filename := submatches[2]
+ pkgURL, err := url.Parse(uri)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusBadGateway)
+ return false
+ }
+
+ if pkgURL.Fragment == "" {
+ log.Println(r.RemoteAddr, "pypi", filename, "no digest provided")
+ http.Error(w, "no digest provided", http.StatusBadGateway)
+ return false
+ }
+ digestInfo := strings.Split(pkgURL.Fragment, "=")
+ if len(digestInfo) == 1 {
+ // Ancient non PEP-0503 PyPIs, assume MD5
+ digestInfo = []string{"md5", digestInfo[0]}
+ } else if len(digestInfo) != 2 {
+ log.Println(r.RemoteAddr, "pypi", filename, "invalid digest provided")
+ http.Error(w, "invalid digest provided", http.StatusBadGateway)
+ return false
+ }
+ digest, err := hex.DecodeString(digestInfo[1])
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusBadGateway)
+ return false
+ }
+ hashAlgo := digestInfo[0]
+ var hasherNew func() hash.Hash
+ var hashSize int
+ switch hashAlgo {
+ case HashAlgoMD5:
+ hasherNew = md5.New
+ hashSize = md5.Size
+ case HashAlgoSHA256:
+ hasherNew = sha256.New
+ hashSize = sha256.Size
+ case HashAlgoSHA512:
+ hasherNew = sha512.New
+ hashSize = sha512.Size
+ case HashAlgoBLAKE2b256:
+ hasherNew = blake2b256New
+ hashSize = blake2b.Size256
+ default:
+ log.Println(
+ r.RemoteAddr, "pypi", filename,
+ "unknown digest algorithm", hashAlgo,
+ )
+ http.Error(w, "unknown digest algorithm", http.StatusBadGateway)
+ return false
+ }
+ if len(digest) != hashSize {
+ log.Println(r.RemoteAddr, "pypi", filename, "invalid digest length")
+ http.Error(w, "invalid digest length", http.StatusBadGateway)
+ return false
+ }
+
+ pkgURL.Fragment = ""
+ if pkgURL.Host == "" {
+ uri = pypiURLParsed.ResolveReference(pkgURL).String()
+ } else {
+ uri = pkgURL.String()
+ }
+
+ path := filepath.Join(dirPath, filename)
+ if filename == filenameGet {
+ if killed {
+ // Skip heavy remote call, when shutting down
+ http.Error(w, "shutting down", http.StatusInternalServerError)
+ return false
+ }
+ log.Println(r.RemoteAddr, "pypi download", filename)
+ resp, err = http.Get(uri)
+ if err != nil {
+ log.Println(r.RemoteAddr, "pypi download error:", err.Error())
+ http.Error(w, err.Error(), http.StatusBadGateway)
+ return false
+ }
+ defer resp.Body.Close()
+ hasher := hasherNew()
+ hasherSHA256 := sha256.New()
+ dst, err := TempFile(dirPath)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return false
+ }
+ dstBuf := bufio.NewWriter(dst)
+ wrs := []io.Writer{hasher, dstBuf}
+ if hashAlgo != HashAlgoSHA256 {
+ wrs = append(wrs, hasherSHA256)
+ }
+ wr := io.MultiWriter(wrs...)
+ if _, err = io.Copy(wr, resp.Body); err != nil {
+ os.Remove(dst.Name())
+ dst.Close()
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return false
+ }
+ if err = dstBuf.Flush(); err != nil {
+ os.Remove(dst.Name())
+ dst.Close()
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return false
+ }
+ if bytes.Compare(hasher.Sum(nil), digest) != 0 {
+ log.Println(r.RemoteAddr, "pypi", filename, "digest mismatch")
+ os.Remove(dst.Name())
+ dst.Close()
+ http.Error(w, "digest mismatch", http.StatusBadGateway)
+ return false
+ }
+ if err = dst.Sync(); err != nil {
+ os.Remove(dst.Name())
+ dst.Close()
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return false
+ }
+ if err = dst.Close(); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return false
+ }
+ if err = os.Rename(dst.Name(), path); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return false
+ }
+ if err = DirSync(dirPath); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return false
+ }
+ if hashAlgo != HashAlgoSHA256 {
+ hashAlgo = HashAlgoSHA256
+ digest = hasherSHA256.Sum(nil)
+ for _, algo := range knownHashAlgos[1:] {
+ os.Remove(path + "." + algo)
+ }
+ }
+ }
+ if filename == filenameGet || gpgUpdate {
+ if _, err = os.Stat(path); err != nil {
+ goto GPGSigSkip
+ }
+ resp, err := http.Get(uri + GPGSigExt)
+ if err != nil {
+ goto GPGSigSkip
+ }
+ if resp.StatusCode != http.StatusOK {
+ resp.Body.Close()
+ goto GPGSigSkip
+ }
+ sig, err := ioutil.ReadAll(resp.Body)
+ resp.Body.Close()
+ if err != nil {
+ goto GPGSigSkip
+ }
+ if !bytes.HasPrefix(sig, []byte("-----BEGIN PGP SIGNATURE-----")) {
+ log.Println(r.RemoteAddr, "pypi non PGP signature", filename)
+ goto GPGSigSkip
+ }
+ if err = WriteFileSync(dirPath, path+GPGSigExt, sig); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return false
+ }
+ log.Println(r.RemoteAddr, "pypi downloaded signature", filename)
+ }
+ GPGSigSkip:
+ path = path + "." + hashAlgo
+ _, err = os.Stat(path)
+ if err == nil {
+ continue
+ }
+ if !os.IsNotExist(err) {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return false
+ }
+ log.Println(r.RemoteAddr, "pypi touch", filename)
+ if err = WriteFileSync(dirPath, path, digest); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return false
+ }
+ }
+ return true
+}