]> Cypherpunks.ru repositories - nncp.git/blobdiff - src/mth.go
Sequential MTH optimization
[nncp.git] / src / mth.go
index 6cb1f3640b6f180fa363f6b8c9b7664f4ff0787a..a766925cee3b6e9b3db90a679486bc97d5fc3103 100644 (file)
@@ -20,6 +20,7 @@ package nncp
 import (
        "bytes"
        "errors"
+       "hash"
        "io"
 
        "lukechampine.com/blake3"
@@ -50,21 +51,29 @@ type MTHEvent struct {
        Hsh   []byte
 }
 
-type MTH struct {
+type MTH interface {
+       hash.Hash
+       PrependFrom(r io.Reader) (int, error)
+       SetPktName(n string)
+       PrependSize() int64
+       Events() chan MTHEvent
+}
+
+type MTHFat struct {
        size        int64
-       PrependSize int64
+       prependSize int64
        skip        int64
        skipped     bool
        hasher      *blake3.Hasher
        hashes      [][MTHSize]byte
        buf         *bytes.Buffer
        finished    bool
-       Events      chan MTHEvent
-       PktName     string
+       events      chan MTHEvent
+       pktName     string
 }
 
-func MTHNew(size, offset int64) *MTH {
-       mth := MTH{
+func MTHFatNew(size, offset int64) MTH {
+       mth := MTHFat{
                hasher: blake3.New(MTHSize, MTHLeafKey[:]),
                buf:    bytes.NewBuffer(make([]byte, 0, 2*MTHBlockSize)),
        }
@@ -86,19 +95,28 @@ func MTHNew(size, offset int64) *MTH {
                skip = size - offset
        }
        mth.size = size
-       mth.PrependSize = prependSize
+       mth.prependSize = prependSize
        mth.skip = skip
        mth.hashes = make([][MTHSize]byte, prepends, 1+size/MTHBlockSize)
        return &mth
 }
 
-func (mth *MTH) Reset() { panic("not implemented") }
+func (mth *MTHFat) Events() chan MTHEvent {
+       mth.events = make(chan MTHEvent)
+       return mth.events
+}
+
+func (mth *MTHFat) SetPktName(pktName string) { mth.pktName = pktName }
+
+func (mth *MTHFat) PrependSize() int64 { return mth.prependSize }
 
-func (mth *MTH) Size() int { return MTHSize }
+func (mth *MTHFat) Reset() { panic("not implemented") }
 
-func (mth *MTH) BlockSize() int { return MTHBlockSize }
+func (mth *MTHFat) Size() int { return MTHSize }
 
-func (mth *MTH) Write(data []byte) (int, error) {
+func (mth *MTHFat) BlockSize() int { return MTHBlockSize }
+
+func (mth *MTHFat) Write(data []byte) (int, error) {
        if mth.finished {
                return 0, errors.New("already Sum()ed")
        }
@@ -118,8 +136,8 @@ func (mth *MTH) Write(data []byte) (int, error) {
                mth.hasher.Sum(h[:0])
                mth.hasher.Reset()
                mth.hashes = append(mth.hashes, *h)
-               if mth.Events != nil {
-                       mth.Events <- MTHEvent{
+               if mth.events != nil {
+                       mth.events <- MTHEvent{
                                MTHEventAppend,
                                0, len(mth.hashes) - 1,
                                mth.hashes[len(mth.hashes)-1][:],
@@ -129,19 +147,19 @@ func (mth *MTH) Write(data []byte) (int, error) {
        return n, err
 }
 
-func (mth *MTH) PrependFrom(r io.Reader) (int, error) {
+func (mth *MTHFat) PrependFrom(r io.Reader) (int, error) {
        if mth.finished {
                return 0, errors.New("already Sum()ed")
        }
        var err error
        buf := make([]byte, MTHBlockSize)
        var i, n, read int
-       fullsize := mth.PrependSize
-       les := LEs{{"Pkt", mth.PktName}, {"FullSize", fullsize}, {"Size", 0}}
-       for mth.PrependSize >= MTHBlockSize {
+       fullsize := mth.prependSize
+       les := LEs{{"Pkt", mth.pktName}, {"FullSize", fullsize}, {"Size", 0}}
+       for mth.prependSize >= MTHBlockSize {
                n, err = io.ReadFull(r, buf)
                read += n
-               mth.PrependSize -= MTHBlockSize
+               mth.prependSize -= MTHBlockSize
                if err != nil {
                        return read, err
                }
@@ -150,30 +168,30 @@ func (mth *MTH) PrependFrom(r io.Reader) (int, error) {
                }
                mth.hasher.Sum(mth.hashes[i][:0])
                mth.hasher.Reset()
-               if mth.Events != nil {
-                       mth.Events <- MTHEvent{MTHEventPrepend, 0, i, mth.hashes[i][:]}
+               if mth.events != nil {
+                       mth.events <- MTHEvent{MTHEventPrepend, 0, i, mth.hashes[i][:]}
                }
-               if mth.PktName != "" {
+               if mth.pktName != "" {
                        les[len(les)-1].V = int64(read)
                        Progress("check", les)
                }
                i++
        }
-       if mth.PrependSize > 0 {
-               n, err = io.ReadFull(r, buf[:mth.PrependSize])
+       if mth.prependSize > 0 {
+               n, err = io.ReadFull(r, buf[:mth.prependSize])
                read += n
                if err != nil {
                        return read, err
                }
-               if _, err = mth.hasher.Write(buf[:mth.PrependSize]); err != nil {
+               if _, err = mth.hasher.Write(buf[:mth.prependSize]); err != nil {
                        panic(err)
                }
                mth.hasher.Sum(mth.hashes[i][:0])
                mth.hasher.Reset()
-               if mth.Events != nil {
-                       mth.Events <- MTHEvent{MTHEventPrepend, 0, i, mth.hashes[i][:]}
+               if mth.events != nil {
+                       mth.events <- MTHEvent{MTHEventPrepend, 0, i, mth.hashes[i][:]}
                }
-               if mth.PktName != "" {
+               if mth.pktName != "" {
                        les[len(les)-1].V = fullsize
                        Progress("check", les)
                }
@@ -181,7 +199,7 @@ func (mth *MTH) PrependFrom(r io.Reader) (int, error) {
        return read, nil
 }
 
-func (mth *MTH) Sum(b []byte) []byte {
+func (mth *MTHFat) Sum(b []byte) []byte {
        if mth.finished {
                return append(b, mth.hashes[0][:]...)
        }
@@ -194,8 +212,8 @@ func (mth *MTH) Sum(b []byte) []byte {
                mth.hasher.Sum(h[:0])
                mth.hasher.Reset()
                mth.hashes = append(mth.hashes, *h)
-               if mth.Events != nil {
-                       mth.Events <- MTHEvent{
+               if mth.events != nil {
+                       mth.events <- MTHEvent{
                                MTHEventAppend,
                                0, len(mth.hashes) - 1,
                                mth.hashes[len(mth.hashes)-1][:],
@@ -211,14 +229,14 @@ func (mth *MTH) Sum(b []byte) []byte {
                mth.hasher.Sum(h[:0])
                mth.hasher.Reset()
                mth.hashes = append(mth.hashes, *h)
-               if mth.Events != nil {
-                       mth.Events <- MTHEvent{MTHEventAppend, 0, 0, mth.hashes[0][:]}
+               if mth.events != nil {
+                       mth.events <- MTHEvent{MTHEventAppend, 0, 0, mth.hashes[0][:]}
                }
                fallthrough
        case 1:
                mth.hashes = append(mth.hashes, mth.hashes[0])
-               if mth.Events != nil {
-                       mth.Events <- MTHEvent{MTHEventAppend, 0, 1, mth.hashes[1][:]}
+               if mth.events != nil {
+                       mth.events <- MTHEvent{MTHEventAppend, 0, 1, mth.hashes[1][:]}
                }
        }
        mth.hasher = blake3.New(MTHSize, MTHNodeKey[:])
@@ -237,8 +255,8 @@ func (mth *MTH) Sum(b []byte) []byte {
                        mth.hasher.Sum(h[:0])
                        mth.hasher.Reset()
                        hashesUp = append(hashesUp, *h)
-                       if mth.Events != nil {
-                               mth.Events <- MTHEvent{
+                       if mth.events != nil {
+                               mth.events <- MTHEvent{
                                        MTHEventFold,
                                        level, len(hashesUp) - 1,
                                        hashesUp[len(hashesUp)-1][:],
@@ -247,8 +265,8 @@ func (mth *MTH) Sum(b []byte) []byte {
                }
                if len(mth.hashes)%2 == 1 {
                        hashesUp = append(hashesUp, mth.hashes[len(mth.hashes)-1])
-                       if mth.Events != nil {
-                               mth.Events <- MTHEvent{
+                       if mth.events != nil {
+                               mth.events <- MTHEvent{
                                        MTHEventAppend,
                                        level, len(hashesUp) - 1,
                                        hashesUp[len(hashesUp)-1][:],
@@ -259,8 +277,170 @@ func (mth *MTH) Sum(b []byte) []byte {
                level++
        }
        mth.finished = true
-       if mth.Events != nil {
-               close(mth.Events)
+       if mth.events != nil {
+               close(mth.events)
        }
        return append(b, mth.hashes[0][:]...)
 }
+
+type MTHSeqEnt struct {
+       l int
+       h [MTHSize]byte
+}
+
+type MTHSeq struct {
+       hasherLeaf *blake3.Hasher
+       hasherNode *blake3.Hasher
+       hashes     []MTHSeqEnt
+       buf        *bytes.Buffer
+       events     chan MTHEvent
+       ctrs       []int
+       finished   bool
+}
+
+func MTHSeqNew() *MTHSeq {
+       mth := MTHSeq{
+               hasherLeaf: blake3.New(MTHSize, MTHLeafKey[:]),
+               hasherNode: blake3.New(MTHSize, MTHNodeKey[:]),
+               buf:        bytes.NewBuffer(make([]byte, 0, 2*MTHBlockSize)),
+               ctrs:       make([]int, 1, 2),
+       }
+       return &mth
+}
+
+func (mth *MTHSeq) Reset() { panic("not implemented") }
+
+func (mth *MTHSeq) Size() int { return MTHSize }
+
+func (mth *MTHSeq) BlockSize() int { return MTHBlockSize }
+
+func (mth *MTHSeq) PrependFrom(r io.Reader) (int, error) {
+       panic("must not reach that code")
+}
+
+func (mth *MTHSeq) Events() chan MTHEvent {
+       mth.events = make(chan MTHEvent)
+       return mth.events
+}
+
+func (mth *MTHSeq) SetPktName(pktName string) {}
+
+func (mth *MTHSeq) PrependSize() int64 { return 0 }
+
+func (mth *MTHSeq) leafAdd() {
+       ent := MTHSeqEnt{l: 0}
+       mth.hasherLeaf.Sum(ent.h[:0])
+       mth.hasherLeaf.Reset()
+       mth.hashes = append(mth.hashes, ent)
+       if mth.events != nil {
+               mth.events <- MTHEvent{
+                       MTHEventAppend, 0, mth.ctrs[0],
+                       mth.hashes[len(mth.hashes)-1].h[:],
+               }
+       }
+       mth.ctrs[0]++
+}
+
+func (mth *MTHSeq) incr(l int) {
+       if len(mth.ctrs) <= l {
+               mth.ctrs = append(mth.ctrs, 0)
+       } else {
+               mth.ctrs[l]++
+       }
+}
+
+func (mth *MTHSeq) fold() {
+       for len(mth.hashes) >= 2 {
+               if mth.hashes[len(mth.hashes)-2].l != mth.hashes[len(mth.hashes)-1].l {
+                       break
+               }
+               if _, err := mth.hasherNode.Write(mth.hashes[len(mth.hashes)-2].h[:]); err != nil {
+                       panic(err)
+               }
+               if _, err := mth.hasherNode.Write(mth.hashes[len(mth.hashes)-1].h[:]); err != nil {
+                       panic(err)
+               }
+               mth.hashes = mth.hashes[:len(mth.hashes)-1]
+               end := &mth.hashes[len(mth.hashes)-1]
+               end.l++
+               mth.incr(end.l)
+               mth.hasherNode.Sum(end.h[:0])
+               mth.hasherNode.Reset()
+               if mth.events != nil {
+                       mth.events <- MTHEvent{MTHEventFold, end.l, mth.ctrs[end.l], end.h[:]}
+               }
+       }
+}
+
+func (mth *MTHSeq) Write(data []byte) (int, error) {
+       if mth.finished {
+               return 0, errors.New("already Sum()ed")
+       }
+       n, err := mth.buf.Write(data)
+       if err != nil {
+               return n, err
+       }
+       for mth.buf.Len() >= MTHBlockSize {
+               if _, err = mth.hasherLeaf.Write(mth.buf.Next(MTHBlockSize)); err != nil {
+                       return n, err
+               }
+               mth.leafAdd()
+               mth.fold()
+       }
+       return n, err
+}
+
+func (mth *MTHSeq) Sum(b []byte) []byte {
+       if mth.finished {
+               return append(b, mth.hashes[0].h[:]...)
+       }
+       if mth.buf.Len() > 0 {
+               if _, err := mth.hasherLeaf.Write(mth.buf.Next(MTHBlockSize)); err != nil {
+                       panic(err)
+               }
+               mth.leafAdd()
+               mth.fold()
+       }
+       switch mth.ctrs[0] {
+       case 0:
+               if _, err := mth.hasherLeaf.Write(nil); err != nil {
+                       panic(err)
+               }
+               mth.leafAdd()
+               fallthrough
+       case 1:
+               mth.hashes = append(mth.hashes, mth.hashes[0])
+               mth.ctrs[0]++
+               if mth.events != nil {
+                       mth.events <- MTHEvent{
+                               MTHEventAppend, 0, mth.ctrs[0],
+                               mth.hashes[len(mth.hashes)-1].h[:],
+                       }
+               }
+               mth.fold()
+       }
+       for len(mth.hashes) >= 2 {
+               l := mth.hashes[len(mth.hashes)-2].l
+               mth.incr(l)
+               mth.hashes[len(mth.hashes)-1].l = l
+               if mth.events != nil {
+                       mth.events <- MTHEvent{
+                               MTHEventAppend, l, mth.ctrs[l],
+                               mth.hashes[len(mth.hashes)-1].h[:],
+                       }
+               }
+               mth.fold()
+       }
+       mth.finished = true
+       if mth.events != nil {
+               close(mth.events)
+       }
+       return append(b, mth.hashes[0].h[:]...)
+}
+
+func MTHNew(size, offset int64) MTH {
+       if offset == 0 {
+               return MTHSeqNew()
+       }
+       return MTHFatNew(size, offset)
+}