]> Cypherpunks.ru repositories - gocheese.git/blobdiff - refresh.go
Explicitly check stored digest
[gocheese.git] / refresh.go
index 6fef8f600ea352d587d791607dc0ceaefe60fe65..0138e7ddfbf6f93e29e13f318ee025064ac6b640 100644 (file)
@@ -25,6 +25,7 @@ import (
        "crypto/sha512"
        "encoding/hex"
        "encoding/json"
+       "errors"
        "hash"
        "io"
        "io/ioutil"
@@ -338,12 +339,17 @@ func refreshDir(
                        hasherNew = blake2b256New
                        hashSize = blake2b.Size256
                default:
-                       log.Println("error", r.RemoteAddr, "pypi", filename, "unknown digest", hashAlgo)
+                       log.Println(
+                               "error", r.RemoteAddr, "pypi",
+                               filename, "unknown digest", hashAlgo,
+                       )
                        http.Error(w, "unknown digest algorithm", http.StatusBadGateway)
                        return false
                }
                if len(digest) != hashSize {
-                       log.Println("error", r.RemoteAddr, "pypi", filename, "invalid digest length")
+                       log.Println(
+                               "error", r.RemoteAddr, "pypi",
+                               filename, "invalid digest length")
                        http.Error(w, "invalid digest length", http.StatusBadGateway)
                        return false
                }
@@ -422,6 +428,15 @@ func refreshDir(
                                http.Error(w, "digest mismatch", http.StatusBadGateway)
                                return false
                        }
+                       if digestStored, err := ioutil.ReadFile(path + "." + hashAlgo); err == nil &&
+                               bytes.Compare(digest, digestStored) != 0 {
+                               err = errors.New("stored digest mismatch")
+                               log.Println("error", r.RemoteAddr, "pypi", filename, err)
+                               os.Remove(dst.Name())
+                               dst.Close()
+                               http.Error(w, err.Error(), http.StatusInternalServerError)
+                               return false
+                       }
                        if !NoSync {
                                if err = dst.Sync(); err != nil {
                                        os.Remove(dst.Name())
@@ -546,8 +561,8 @@ func refreshDir(
                }
                path = path + "." + hashAlgo
                stat, err := os.Stat(path)
-               if err == nil &&
-                       (mtimeExists && stat.ModTime().Truncate(time.Second).Equal(mtime)) {
+               if err == nil && (!mtimeExists ||
+                       (mtimeExists && stat.ModTime().Truncate(time.Second).Equal(mtime))) {
                        continue
                }
                if err != nil && !os.IsNotExist(err) {