]> Cypherpunks.ru repositories - nncp.git/commitdiff
MTH memory usage optimization
authorSergey Matveev <stargrave@stargrave.org>
Sun, 11 Jul 2021 13:55:12 +0000 (16:55 +0300)
committerSergey Matveev <stargrave@stargrave.org>
Sun, 11 Jul 2021 18:48:44 +0000 (21:48 +0300)
doc/news.ru.texi
doc/news.texi
src/check.go
src/cmd/nncp-hash/main.go
src/mth.go
src/mth_test.go
src/sp.go

index cb70cd861c84a6a015599ef84efaf1357fa94f0e..82336d06d03addbf183868d5fac51aea330079cb 100644 (file)
@@ -62,6 +62,9 @@
 Исправлена проблема с возможно остающимся открытым файловым
 дескриптором в online командах.
 
+@item
+Существенно снижено потребление памяти MTH хэширования.
+
 @end itemize
 
 @node Релиз 7.3.0
index 3be46521997af74df8fe203d5ac9e849a3337358..07359030c4580dc2c3a619859c18c9f8ede70297 100644 (file)
@@ -63,6 +63,9 @@ See also this page @ref{Новости, on russian}.
 @item
 Fixed possibly left opened file descriptor in online commands.
 
+@item
+Severely decreased memory usage of MTH hashing.
+
 @end itemize
 
 @node Release 7_3_0
index 7a3d5b8041a42a11059b763a5a2ac3e285c54646..d6a98fac805f3bcbff2b84af5da4ef804424d56c 100644 (file)
@@ -103,8 +103,10 @@ func (ctx *Ctx) CheckNoCK(nodeId *NodeId, hshValue *[MTHSize]byte, mth MTH) (int
        if mth == nil {
                gut, err = Check(fd, size, hshValue[:], les, ctx.ShowPrgrs)
        } else {
-               mth.SetPktName(pktName)
-               if _, err = mth.PrependFrom(bufio.NewReaderSize(fd, MTHSize)); err != nil {
+               if _, err = mth.PreaddFrom(
+                       bufio.NewReaderSize(fd, MTHSize),
+                       pktName, ctx.ShowPrgrs,
+               ); err != nil {
                        return 0, err
                }
                if bytes.Compare(mth.Sum(nil), hshValue[:]) == 0 {
index 8b90ab9c0cc5eda0dbf6e1b3bcf309f96fc6c4f7..bcdbca54fe4786cbdb94b19acf1cbf1cfc527646 100644 (file)
@@ -76,64 +76,71 @@ func main() {
                }
                size = fi.Size()
        }
-       var mth nncp.MTH
-       if *forceFat {
-               mth = nncp.MTHFatNew(size, int64(*seek))
-       } else {
-               mth = nncp.MTHNew(size, int64(*seek))
-       }
-       var debugger sync.WaitGroup
+
        if *debug {
                fmt.Println("Leaf BLAKE3 key:", hex.EncodeToString(nncp.MTHLeafKey[:]))
                fmt.Println("Node BLAKE3 key:", hex.EncodeToString(nncp.MTHNodeKey[:]))
-               events := mth.Events()
+       }
+
+       var debugger sync.WaitGroup
+       startDebug := func(events chan nncp.MTHEvent) {
                debugger.Add(1)
                go func() {
                        for e := range events {
-                               var t string
-                               switch e.Type {
-                               case nncp.MTHEventAppend:
-                                       t = "Add"
-                               case nncp.MTHEventPrepend:
-                                       t = "Pre"
-                               case nncp.MTHEventFold:
-                                       t = "Fold"
-                               }
-                               fmt.Printf(
-                                       "%s\t%03d\t%06d\t%s\n",
-                                       t, e.Level, e.Ctr, hex.EncodeToString(e.Hsh),
-                               )
+                               fmt.Println(e.String())
                        }
                        debugger.Done()
                }()
        }
-       if *seek != 0 {
-               if *fn == "" {
-                       log.Fatalln("-file is required with -seek")
+       copier := func(w io.Writer) error {
+               _, err := nncp.CopyProgressed(
+                       w, bufio.NewReaderSize(fd, nncp.MTHBlockSize), "hash",
+                       nncp.LEs{{K: "Pkt", V: *fn}, {K: "FullSize", V: size - int64(*seek)}},
+                       *showPrgrs,
+               )
+               return err
+       }
+
+       var sum []byte
+       if *forceFat {
+               mth := nncp.MTHFatNew()
+               if *debug {
+                       startDebug(mth.Events())
+
                }
-               if _, err = fd.Seek(int64(*seek), io.SeekStart); err != nil {
+               if err = copier(mth); err != nil {
                        log.Fatalln(err)
                }
-       }
-       if _, err = nncp.CopyProgressed(
-               mth, bufio.NewReaderSize(fd, nncp.MTHBlockSize),
-               "hash", nncp.LEs{{K: "Pkt", V: *fn}, {K: "FullSize", V: size - int64(*seek)}},
-               *showPrgrs,
-       ); err != nil {
-               log.Fatalln(err)
-       }
-       if *seek != 0 {
-               if _, err = fd.Seek(0, io.SeekStart); err != nil {
-                       log.Fatalln(err)
+               sum = mth.Sum(nil)
+       } else {
+               mth := nncp.MTHSeqNew(size, int64(*seek))
+               if *debug {
+                       startDebug(mth.Events())
                }
-               if *showPrgrs {
-                       mth.SetPktName(*fn)
+               if *seek != 0 {
+                       if *fn == "" {
+                               log.Fatalln("-file is required with -seek")
+                       }
+                       if _, err = fd.Seek(int64(*seek), io.SeekStart); err != nil {
+                               log.Fatalln(err)
+                       }
                }
-               if _, err = mth.PrependFrom(bufio.NewReaderSize(fd, nncp.MTHBlockSize)); err != nil {
+               if err = copier(mth); err != nil {
                        log.Fatalln(err)
                }
+               if *seek != 0 {
+                       if _, err = fd.Seek(0, io.SeekStart); err != nil {
+                               log.Fatalln(err)
+                       }
+                       if _, err = mth.PreaddFrom(
+                               bufio.NewReaderSize(fd, nncp.MTHBlockSize),
+                               *fn, *showPrgrs,
+                       ); err != nil {
+                               log.Fatalln(err)
+                       }
+               }
+               sum = mth.Sum(nil)
        }
-       sum := mth.Sum(nil)
        debugger.Wait()
        fmt.Println(hex.EncodeToString(sum))
 }
index d6e3082e3f6ee11beac1d2eea0226fc5b7e468e9..031c1cd8ee3ea5c3a171b48b736f67975a531908 100644 (file)
@@ -19,7 +19,9 @@ package nncp
 
 import (
        "bytes"
+       "encoding/hex"
        "errors"
+       "fmt"
        "hash"
        "io"
 
@@ -36,87 +38,168 @@ var (
        MTHNodeKey = blake3.Sum256([]byte("NNCP MTH NODE"))
 )
 
-type MTHEventType uint8
+type MTHSeqEnt struct {
+       l int
+       c int64
+       h [MTHSize]byte
+}
+
+func (ent *MTHSeqEnt) String() string {
+       return fmt.Sprintf("%03d\t%06d\t%s", ent.l, ent.c, hex.EncodeToString(ent.h[:]))
+}
+
+type MTHEventType string
 
 const (
-       MTHEventAppend  MTHEventType = iota
-       MTHEventPrepend MTHEventType = iota
-       MTHEventFold    MTHEventType = iota
+       MTHEventAdd    MTHEventType = "Add"
+       MTHEventPreadd MTHEventType = "Pre"
+       MTHEventFold   MTHEventType = "Fold"
 )
 
 type MTHEvent struct {
-       Type  MTHEventType
-       Level int64
-       Ctr   int64
-       Hsh   []byte
+       Type MTHEventType
+       Ent  *MTHSeqEnt
+}
+
+func (e MTHEvent) String() string {
+       return fmt.Sprintf("%s\t%s", e.Type, e.Ent.String())
 }
 
 type MTH interface {
        hash.Hash
-       PrependFrom(r io.Reader) (int64, error)
-       SetPktName(n string)
-       PrependSize() int64
+       PreaddFrom(r io.Reader, pktName string, showPrgrs bool) (int64, error)
+       PreaddSize() int64
        Events() chan MTHEvent
 }
 
-type MTHFat struct {
+type MTHSeq struct {
+       hasherLeaf  *blake3.Hasher
+       hasherNode  *blake3.Hasher
+       hashes      []MTHSeqEnt
+       buf         *bytes.Buffer
+       events      chan MTHEvent
+       ctr         int64
        size        int64
        prependSize int64
-       skip        int64
+       toSkip      int64
        skipped     bool
-       hasher      *blake3.Hasher
-       hashes      [][MTHSize]byte
-       buf         *bytes.Buffer
        finished    bool
-       events      chan MTHEvent
        pktName     string
 }
 
-func MTHFatNew(size, offset int64) MTH {
-       mth := MTHFat{
-               hasher: blake3.New(MTHSize, MTHLeafKey[:]),
-               buf:    bytes.NewBuffer(make([]byte, 0, 2*MTHBlockSize)),
+func MTHSeqNew(size, offset int64) *MTHSeq {
+       mth := MTHSeq{
+               hasherLeaf: blake3.New(MTHSize, MTHLeafKey[:]),
+               hasherNode: blake3.New(MTHSize, MTHNodeKey[:]),
+               buf:        bytes.NewBuffer(make([]byte, 0, 2*MTHBlockSize)),
        }
        if size == 0 {
                return &mth
        }
        prepends := offset / MTHBlockSize
-       skip := MTHBlockSize - (offset - prepends*MTHBlockSize)
-       if skip == MTHBlockSize {
-               skip = 0
-       } else if skip > 0 {
+       toSkip := MTHBlockSize - (offset - prepends*MTHBlockSize)
+       if toSkip == MTHBlockSize {
+               toSkip = 0
+       } else if toSkip > 0 {
                prepends++
        }
        prependSize := prepends * MTHBlockSize
+       mth.ctr = prepends
        if prependSize > size {
                prependSize = size
        }
-       if offset+skip > size {
-               skip = size - offset
+       if offset+toSkip > size {
+               toSkip = size - offset
        }
        mth.size = size
        mth.prependSize = prependSize
-       mth.skip = skip
-       mth.hashes = make([][MTHSize]byte, prepends, 1+size/MTHBlockSize)
+       mth.toSkip = toSkip
        return &mth
 }
 
-func (mth *MTHFat) Events() chan MTHEvent {
+func (mth *MTHSeq) Reset() { panic("not implemented") }
+
+func (mth *MTHSeq) Size() int { return MTHSize }
+
+func (mth *MTHSeq) BlockSize() int { return MTHBlockSize }
+
+func (mth *MTHSeq) PreaddFrom(r io.Reader, pktName string, showPrgrs bool) (int64, error) {
+       if mth.finished {
+               return 0, errors.New("already Sum()ed")
+       }
+       if mth.buf.Len() > 0 {
+               if _, err := mth.hasherLeaf.Write(mth.buf.Next(MTHBlockSize)); err != nil {
+                       panic(err)
+               }
+               mth.leafAdd()
+               mth.fold()
+       }
+       prevHashes := mth.hashes
+       mth.hashes = nil
+       prevCtr := mth.ctr
+       mth.ctr = 0
+       lr := io.LimitedReader{R: r, N: mth.prependSize}
+       les := LEs{{"Pkt", pktName}, {"FullSize", mth.prependSize}}
+       n, err := CopyProgressed(mth, &lr, "prehash", les, showPrgrs)
+       for _, ent := range prevHashes {
+               mth.hashes = append(mth.hashes, ent)
+               mth.fold()
+       }
+       if mth.buf.Len() > 0 {
+               mth.ctr = prevCtr - 1
+       } else {
+               mth.ctr = prevCtr
+       }
+       return n, err
+}
+
+func (mth *MTHSeq) Events() chan MTHEvent {
        mth.events = make(chan MTHEvent)
        return mth.events
 }
 
-func (mth *MTHFat) SetPktName(pktName string) { mth.pktName = pktName }
-
-func (mth *MTHFat) PrependSize() int64 { return mth.prependSize }
+func (mth *MTHSeq) PreaddSize() int64 { return mth.prependSize }
 
-func (mth *MTHFat) Reset() { panic("not implemented") }
-
-func (mth *MTHFat) Size() int { return MTHSize }
+func (mth *MTHSeq) leafAdd() {
+       ent := MTHSeqEnt{c: mth.ctr}
+       mth.hasherLeaf.Sum(ent.h[:0])
+       mth.hasherLeaf.Reset()
+       mth.hashes = append(mth.hashes, ent)
+       mth.ctr++
+       if mth.events != nil {
+               mth.events <- MTHEvent{MTHEventAdd, &ent}
+       }
+}
 
-func (mth *MTHFat) BlockSize() int { return MTHBlockSize }
+func (mth *MTHSeq) fold() {
+       for len(mth.hashes) >= 2 {
+               hlen := len(mth.hashes)
+               end1 := &mth.hashes[hlen-2]
+               end0 := &mth.hashes[hlen-1]
+               if end1.c%2 == 1 {
+                       break
+               }
+               if end1.l != end0.l {
+                       break
+               }
+               if _, err := mth.hasherNode.Write(end1.h[:]); err != nil {
+                       panic(err)
+               }
+               if _, err := mth.hasherNode.Write(end0.h[:]); err != nil {
+                       panic(err)
+               }
+               mth.hashes = mth.hashes[:hlen-1]
+               end1.l++
+               end1.c /= 2
+               mth.hasherNode.Sum(end1.h[:0])
+               mth.hasherNode.Reset()
+               if mth.events != nil {
+                       mth.events <- MTHEvent{MTHEventFold, end1}
+               }
+       }
+}
 
-func (mth *MTHFat) Write(data []byte) (int, error) {
+func (mth *MTHSeq) Write(data []byte) (int, error) {
        if mth.finished {
                return 0, errors.New("already Sum()ed")
        }
@@ -124,86 +207,121 @@ func (mth *MTHFat) Write(data []byte) (int, error) {
        if err != nil {
                return n, err
        }
-       if mth.skip > 0 && int64(mth.buf.Len()) >= mth.skip {
-               mth.buf.Next(int(mth.skip))
-               mth.skip = 0
+       if mth.toSkip > 0 {
+               if int64(mth.buf.Len()) < mth.toSkip {
+                       return n, err
+               }
+               mth.buf.Next(int(mth.toSkip))
+               mth.toSkip = 0
        }
        for mth.buf.Len() >= MTHBlockSize {
-               if _, err = mth.hasher.Write(mth.buf.Next(MTHBlockSize)); err != nil {
+               if _, err = mth.hasherLeaf.Write(mth.buf.Next(MTHBlockSize)); err != nil {
                        return n, err
                }
-               h := new([MTHSize]byte)
-               mth.hasher.Sum(h[:0])
-               mth.hasher.Reset()
-               mth.hashes = append(mth.hashes, *h)
-               if mth.events != nil {
-                       mth.events <- MTHEvent{
-                               MTHEventAppend,
-                               0, int64(len(mth.hashes) - 1),
-                               mth.hashes[len(mth.hashes)-1][:],
-                       }
-               }
+               mth.leafAdd()
+               mth.fold()
        }
        return n, err
 }
 
-func (mth *MTHFat) PrependFrom(r io.Reader) (int64, error) {
+func (mth *MTHSeq) Sum(b []byte) []byte {
        if mth.finished {
-               return 0, errors.New("already Sum()ed")
+               return append(b, mth.hashes[0].h[:]...)
        }
-       var err error
-       buf := make([]byte, MTHBlockSize)
-       var n int
-       var i, read int64
-       fullsize := mth.prependSize
-       les := LEs{{"Pkt", mth.pktName}, {"FullSize", fullsize}, {"Size", 0}}
-       for mth.prependSize >= MTHBlockSize {
-               n, err = io.ReadFull(r, buf)
-               read += int64(n)
-               mth.prependSize -= MTHBlockSize
-               if err != nil {
-                       return read, err
+       if mth.buf.Len() > 0 {
+               if _, err := mth.hasherLeaf.Write(mth.buf.Next(MTHBlockSize)); err != nil {
+                       panic(err)
                }
-               if _, err = mth.hasher.Write(buf); err != nil {
+               mth.leafAdd()
+               mth.fold()
+       }
+       switch mth.ctr {
+       case 0:
+               if _, err := mth.hasherLeaf.Write(nil); err != nil {
                        panic(err)
                }
-               mth.hasher.Sum(mth.hashes[i][:0])
-               mth.hasher.Reset()
+               mth.leafAdd()
+               fallthrough
+       case 1:
+               ent := MTHSeqEnt{c: 1}
+               copy(ent.h[:], mth.hashes[0].h[:])
+               mth.ctr = 2
+               mth.hashes = append(mth.hashes, ent)
                if mth.events != nil {
-                       mth.events <- MTHEvent{MTHEventPrepend, 0, i, mth.hashes[i][:]}
+                       mth.events <- MTHEvent{MTHEventAdd, &ent}
                }
-               if mth.pktName != "" {
-                       les[len(les)-1].V = read
-                       Progress("check", les)
-               }
-               i++
+               mth.fold()
        }
-       if mth.prependSize > 0 {
-               n, err = io.ReadFull(r, buf[:mth.prependSize])
-               read += int64(n)
-               if err != nil {
-                       return read, err
+       for len(mth.hashes) >= 2 {
+               hlen := len(mth.hashes)
+               end1 := &mth.hashes[hlen-2]
+               end0 := &mth.hashes[hlen-1]
+               end0.l = end1.l
+               end0.c = end1.c + 1
+               if mth.events != nil {
+                       mth.events <- MTHEvent{MTHEventAdd, end0}
                }
-               if _, err = mth.hasher.Write(buf[:mth.prependSize]); err != nil {
-                       panic(err)
+               mth.fold()
+       }
+       mth.finished = true
+       if mth.events != nil {
+               close(mth.events)
+       }
+       return append(b, mth.hashes[0].h[:]...)
+}
+
+func MTHNew(size, offset int64) MTH {
+       return MTHSeqNew(size, offset)
+}
+
+// Some kind of reference implementation (fat, because eats memory)
+
+type MTHFat struct {
+       hasher *blake3.Hasher
+       hashes [][MTHSize]byte
+       buf    *bytes.Buffer
+       events chan MTHEvent
+}
+
+func MTHFatNew() *MTHFat {
+       return &MTHFat{
+               hasher: blake3.New(MTHSize, MTHLeafKey[:]),
+               buf:    bytes.NewBuffer(make([]byte, 0, 2*MTHBlockSize)),
+       }
+}
+
+func (mth *MTHFat) Events() chan MTHEvent {
+       mth.events = make(chan MTHEvent)
+       return mth.events
+}
+
+func (mth *MTHFat) Write(data []byte) (int, error) {
+       n, err := mth.buf.Write(data)
+       if err != nil {
+               return n, err
+       }
+       for mth.buf.Len() >= MTHBlockSize {
+               if _, err = mth.hasher.Write(mth.buf.Next(MTHBlockSize)); err != nil {
+                       return n, err
                }
-               mth.hasher.Sum(mth.hashes[i][:0])
+               h := new([MTHSize]byte)
+               mth.hasher.Sum(h[:0])
                mth.hasher.Reset()
+               mth.hashes = append(mth.hashes, *h)
                if mth.events != nil {
-                       mth.events <- MTHEvent{MTHEventPrepend, 0, i, mth.hashes[i][:]}
-               }
-               if mth.pktName != "" {
-                       les[len(les)-1].V = fullsize
-                       Progress("check", les)
+                       mth.events <- MTHEvent{
+                               MTHEventAdd,
+                               &MTHSeqEnt{
+                                       0, int64(len(mth.hashes) - 1),
+                                       mth.hashes[len(mth.hashes)-1],
+                               },
+                       }
                }
        }
-       return read, nil
+       return n, err
 }
 
 func (mth *MTHFat) Sum(b []byte) []byte {
-       if mth.finished {
-               return append(b, mth.hashes[0][:]...)
-       }
        if mth.buf.Len() > 0 {
                b := mth.buf.Next(MTHBlockSize)
                if _, err := mth.hasher.Write(b); err != nil {
@@ -215,9 +333,11 @@ func (mth *MTHFat) Sum(b []byte) []byte {
                mth.hashes = append(mth.hashes, *h)
                if mth.events != nil {
                        mth.events <- MTHEvent{
-                               MTHEventAppend,
-                               0, int64(len(mth.hashes) - 1),
-                               mth.hashes[len(mth.hashes)-1][:],
+                               MTHEventAdd,
+                               &MTHSeqEnt{
+                                       0, int64(len(mth.hashes) - 1),
+                                       mth.hashes[len(mth.hashes)-1],
+                               },
                        }
                }
        }
@@ -231,17 +351,17 @@ func (mth *MTHFat) Sum(b []byte) []byte {
                mth.hasher.Reset()
                mth.hashes = append(mth.hashes, *h)
                if mth.events != nil {
-                       mth.events <- MTHEvent{MTHEventAppend, 0, 0, mth.hashes[0][:]}
+                       mth.events <- MTHEvent{MTHEventAdd, &MTHSeqEnt{0, 0, mth.hashes[0]}}
                }
                fallthrough
        case 1:
                mth.hashes = append(mth.hashes, mth.hashes[0])
                if mth.events != nil {
-                       mth.events <- MTHEvent{MTHEventAppend, 0, 1, mth.hashes[1][:]}
+                       mth.events <- MTHEvent{MTHEventAdd, &MTHSeqEnt{0, 1, mth.hashes[1]}}
                }
        }
        mth.hasher = blake3.New(MTHSize, MTHNodeKey[:])
-       level := int64(1)
+       level := 1
        for len(mth.hashes) != 1 {
                hashesUp := make([][MTHSize]byte, 0, 1+len(mth.hashes)/2)
                pairs := (len(mth.hashes) / 2) * 2
@@ -259,8 +379,10 @@ func (mth *MTHFat) Sum(b []byte) []byte {
                        if mth.events != nil {
                                mth.events <- MTHEvent{
                                        MTHEventFold,
-                                       level, int64(len(hashesUp) - 1),
-                                       hashesUp[len(hashesUp)-1][:],
+                                       &MTHSeqEnt{
+                                               level, int64(len(hashesUp) - 1),
+                                               hashesUp[len(hashesUp)-1],
+                                       },
                                }
                        }
                }
@@ -268,180 +390,19 @@ func (mth *MTHFat) Sum(b []byte) []byte {
                        hashesUp = append(hashesUp, mth.hashes[len(mth.hashes)-1])
                        if mth.events != nil {
                                mth.events <- MTHEvent{
-                                       MTHEventAppend,
-                                       level, int64(len(hashesUp) - 1),
-                                       hashesUp[len(hashesUp)-1][:],
+                                       MTHEventAdd,
+                                       &MTHSeqEnt{
+                                               level, int64(len(hashesUp) - 1),
+                                               hashesUp[len(hashesUp)-1],
+                                       },
                                }
                        }
                }
                mth.hashes = hashesUp
                level++
        }
-       mth.finished = true
        if mth.events != nil {
                close(mth.events)
        }
        return append(b, mth.hashes[0][:]...)
 }
-
-type MTHSeqEnt struct {
-       l int64
-       h [MTHSize]byte
-}
-
-type MTHSeq struct {
-       hasherLeaf *blake3.Hasher
-       hasherNode *blake3.Hasher
-       hashes     []MTHSeqEnt
-       buf        *bytes.Buffer
-       events     chan MTHEvent
-       ctrs       []int64
-       finished   bool
-}
-
-func MTHSeqNew() *MTHSeq {
-       mth := MTHSeq{
-               hasherLeaf: blake3.New(MTHSize, MTHLeafKey[:]),
-               hasherNode: blake3.New(MTHSize, MTHNodeKey[:]),
-               buf:        bytes.NewBuffer(make([]byte, 0, 2*MTHBlockSize)),
-               ctrs:       make([]int64, 1, 2),
-       }
-       return &mth
-}
-
-func (mth *MTHSeq) Reset() { panic("not implemented") }
-
-func (mth *MTHSeq) Size() int { return MTHSize }
-
-func (mth *MTHSeq) BlockSize() int { return MTHBlockSize }
-
-func (mth *MTHSeq) PrependFrom(r io.Reader) (int64, error) {
-       panic("must not reach that code")
-}
-
-func (mth *MTHSeq) Events() chan MTHEvent {
-       mth.events = make(chan MTHEvent)
-       return mth.events
-}
-
-func (mth *MTHSeq) SetPktName(pktName string) {}
-
-func (mth *MTHSeq) PrependSize() int64 { return 0 }
-
-func (mth *MTHSeq) leafAdd() {
-       ent := MTHSeqEnt{l: 0}
-       mth.hasherLeaf.Sum(ent.h[:0])
-       mth.hasherLeaf.Reset()
-       mth.hashes = append(mth.hashes, ent)
-       if mth.events != nil {
-               mth.events <- MTHEvent{
-                       MTHEventAppend, 0, mth.ctrs[0],
-                       mth.hashes[len(mth.hashes)-1].h[:],
-               }
-       }
-       mth.ctrs[0]++
-}
-
-func (mth *MTHSeq) incr(l int64) {
-       if int64(len(mth.ctrs)) <= l {
-               mth.ctrs = append(mth.ctrs, 0)
-       } else {
-               mth.ctrs[l]++
-       }
-}
-
-func (mth *MTHSeq) fold() {
-       for len(mth.hashes) >= 2 {
-               if mth.hashes[len(mth.hashes)-2].l != mth.hashes[len(mth.hashes)-1].l {
-                       break
-               }
-               if _, err := mth.hasherNode.Write(mth.hashes[len(mth.hashes)-2].h[:]); err != nil {
-                       panic(err)
-               }
-               if _, err := mth.hasherNode.Write(mth.hashes[len(mth.hashes)-1].h[:]); err != nil {
-                       panic(err)
-               }
-               mth.hashes = mth.hashes[:len(mth.hashes)-1]
-               end := &mth.hashes[len(mth.hashes)-1]
-               end.l++
-               mth.incr(end.l)
-               mth.hasherNode.Sum(end.h[:0])
-               mth.hasherNode.Reset()
-               if mth.events != nil {
-                       mth.events <- MTHEvent{MTHEventFold, end.l, mth.ctrs[end.l], end.h[:]}
-               }
-       }
-}
-
-func (mth *MTHSeq) Write(data []byte) (int, error) {
-       if mth.finished {
-               return 0, errors.New("already Sum()ed")
-       }
-       n, err := mth.buf.Write(data)
-       if err != nil {
-               return n, err
-       }
-       for mth.buf.Len() >= MTHBlockSize {
-               if _, err = mth.hasherLeaf.Write(mth.buf.Next(MTHBlockSize)); err != nil {
-                       return n, err
-               }
-               mth.leafAdd()
-               mth.fold()
-       }
-       return n, err
-}
-
-func (mth *MTHSeq) Sum(b []byte) []byte {
-       if mth.finished {
-               return append(b, mth.hashes[0].h[:]...)
-       }
-       if mth.buf.Len() > 0 {
-               if _, err := mth.hasherLeaf.Write(mth.buf.Next(MTHBlockSize)); err != nil {
-                       panic(err)
-               }
-               mth.leafAdd()
-               mth.fold()
-       }
-       switch mth.ctrs[0] {
-       case 0:
-               if _, err := mth.hasherLeaf.Write(nil); err != nil {
-                       panic(err)
-               }
-               mth.leafAdd()
-               fallthrough
-       case 1:
-               mth.hashes = append(mth.hashes, mth.hashes[0])
-               mth.ctrs[0]++
-               if mth.events != nil {
-                       mth.events <- MTHEvent{
-                               MTHEventAppend, 0, mth.ctrs[0],
-                               mth.hashes[len(mth.hashes)-1].h[:],
-                       }
-               }
-               mth.fold()
-       }
-       for len(mth.hashes) >= 2 {
-               l := mth.hashes[len(mth.hashes)-2].l
-               mth.incr(l)
-               mth.hashes[len(mth.hashes)-1].l = l
-               if mth.events != nil {
-                       mth.events <- MTHEvent{
-                               MTHEventAppend, l, mth.ctrs[l],
-                               mth.hashes[len(mth.hashes)-1].h[:],
-                       }
-               }
-               mth.fold()
-       }
-       mth.finished = true
-       if mth.events != nil {
-               close(mth.events)
-       }
-       return append(b, mth.hashes[0].h[:]...)
-}
-
-func MTHNew(size, offset int64) MTH {
-       if offset == 0 {
-               return MTHSeqNew()
-       }
-       return MTHFatNew(size, offset)
-}
index a4737c09d7efbf6399420ecefc25d04eed074703..2c61c879f5ecff27e179e74d298ec15848336d29 100644 (file)
@@ -26,7 +26,7 @@ import (
        "lukechampine.com/blake3"
 )
 
-func TestMTHFatSymmetric(t *testing.T) {
+func TestMTHSeqSymmetric(t *testing.T) {
        xof := blake3.New(32, nil).XOF()
        f := func(size uint32, offset uint32) bool {
                size %= 2 * 1024 * 1024
@@ -36,31 +36,31 @@ func TestMTHFatSymmetric(t *testing.T) {
                }
                offset = offset % size
 
-               mth := MTHFatNew(int64(size), 0)
+               mth := MTHSeqNew(int64(size), 0)
                if _, err := io.Copy(mth, bytes.NewReader(data)); err != nil {
                        panic(err)
                }
                hsh0 := mth.Sum(nil)
 
-               mth = MTHFatNew(int64(size), int64(offset))
+               mth = MTHSeqNew(int64(size), int64(offset))
                if _, err := io.Copy(mth, bytes.NewReader(data[int(offset):])); err != nil {
                        panic(err)
                }
-               if _, err := mth.PrependFrom(bytes.NewReader(data)); err != nil {
+               if _, err := mth.PreaddFrom(bytes.NewReader(data), "", false); err != nil {
                        panic(err)
                }
                if bytes.Compare(hsh0, mth.Sum(nil)) != 0 {
                        return false
                }
 
-               mth = MTHFatNew(0, 0)
+               mth = MTHSeqNew(0, 0)
                mth.Write(data)
                if bytes.Compare(hsh0, mth.Sum(nil)) != 0 {
                        return false
                }
 
                data = append(data, 0)
-               mth = MTHFatNew(int64(size)+1, 0)
+               mth = MTHSeqNew(int64(size)+1, 0)
                if _, err := io.Copy(mth, bytes.NewReader(data)); err != nil {
                        panic(err)
                }
@@ -69,18 +69,18 @@ func TestMTHFatSymmetric(t *testing.T) {
                        return false
                }
 
-               mth = MTHFatNew(int64(size)+1, int64(offset))
+               mth = MTHSeqNew(int64(size)+1, int64(offset))
                if _, err := io.Copy(mth, bytes.NewReader(data[int(offset):])); err != nil {
                        panic(err)
                }
-               if _, err := mth.PrependFrom(bytes.NewReader(data)); err != nil {
+               if _, err := mth.PreaddFrom(bytes.NewReader(data), "", false); err != nil {
                        panic(err)
                }
                if bytes.Compare(hsh00, mth.Sum(nil)) != 0 {
                        return false
                }
 
-               mth = MTHFatNew(0, 0)
+               mth = MTHSeqNew(0, 0)
                mth.Write(data)
                if bytes.Compare(hsh00, mth.Sum(nil)) != 0 {
                        return false
@@ -101,12 +101,12 @@ func TestMTHSeqAndFatEqual(t *testing.T) {
                if _, err := io.ReadFull(xof, data); err != nil {
                        panic(err)
                }
-               fat := MTHFatNew(int64(size), 0)
+               fat := MTHFatNew()
                if _, err := io.Copy(fat, bytes.NewReader(data)); err != nil {
                        panic(err)
                }
                hshFat := fat.Sum(nil)
-               seq := MTHSeqNew()
+               seq := MTHSeqNew(int64(size), 0)
                if _, err := io.Copy(seq, bytes.NewReader(data)); err != nil {
                        panic(err)
                }
@@ -118,13 +118,13 @@ func TestMTHSeqAndFatEqual(t *testing.T) {
 }
 
 func TestMTHNull(t *testing.T) {
-       fat := MTHFatNew(0, 0)
+       fat := MTHFatNew()
        if _, err := fat.Write(nil); err != nil {
                t.Error(err)
        }
        hshFat := fat.Sum(nil)
 
-       seq := MTHSeqNew()
+       seq := MTHSeqNew(0, 0)
        if _, err := seq.Write(nil); err != nil {
                t.Error(err)
        }
index 58b9bb2f202ef21261b3775b159714fe5f973d48..66fa0375797d2d53fb1cacea0d6143bcf7b0096e 100644 (file)
--- a/src/sp.go
+++ b/src/sp.go
@@ -1447,7 +1447,7 @@ func (state *SPState) ProcessSP(payload []byte) ([][]byte, error) {
                        }
                        if hasherAndOffset != nil {
                                delete(state.fileHashers, filePath)
-                               if hasherAndOffset.mth.PrependSize() == 0 {
+                               if hasherAndOffset.mth.PreaddSize() == 0 {
                                        if bytes.Compare(hasherAndOffset.mth.Sum(nil), file.Hash[:]) != 0 {
                                                state.Ctx.LogE(
                                                        "sp-file-bad-checksum", lesp,