]> Cypherpunks.ru repositories - gogost.git/blobdiff - src/cypherpunks.ru/gogost/gost3410/public.go
Make Public/PrivateKey structure elements public for convenience
[gogost.git] / src / cypherpunks.ru / gogost / gost3410 / public.go
index e8865021122365ea57f64c78dd231e5e0bc704bb..d8d769ff3e30fc94c2f8893c5cf6cbd5b42e687f 100644 (file)
@@ -1,5 +1,5 @@
 // GoGOST -- Pure Go GOST cryptographic functions library
-// Copyright (C) 2015-2016 Sergey Matveev <stargrave@stargrave.org>
+// Copyright (C) 2015-2019 Sergey Matveev <stargrave@stargrave.org>
 //
 // This program is free software: you can redistribute it and/or modify
 // it under the terms of the GNU General Public License as published by
@@ -22,84 +22,85 @@ import (
 )
 
 type PublicKey struct {
-       c  *Curve
-       ds int
-       x  *big.Int
-       y  *big.Int
+       C    *Curve
+       Mode Mode
+       X    *big.Int
+       Y    *big.Int
 }
 
-func NewPublicKey(curve *Curve, ds DigestSize, raw []byte) (*PublicKey, error) {
-       if len(raw) != 2*int(ds) {
+func NewPublicKey(curve *Curve, mode Mode, raw []byte) (*PublicKey, error) {
+       key := make([]byte, 2*int(mode))
+       if len(raw) != len(key) {
                return nil, errors.New("Invalid public key length")
        }
-       key := make([]byte, 2*int(ds))
-       copy(key, raw)
-       reverse(key)
+       for i := 0; i < len(key); i++ {
+               key[i] = raw[len(raw)-i-1]
+       }
        return &PublicKey{
                curve,
-               int(ds),
-               bytes2big(key[int(ds) : 2*int(ds)]),
-               bytes2big(key[:int(ds)]),
+               mode,
+               bytes2big(key[int(mode) : 2*int(mode)]),
+               bytes2big(key[:int(mode)]),
        }, nil
 }
 
-func (pk *PublicKey) Raw() []byte {
-       raw := append(pad(pk.y.Bytes(), pk.ds), pad(pk.x.Bytes(), pk.ds)...)
+func (pub *PublicKey) Raw() []byte {
+       raw := append(
+               pad(pub.Y.Bytes(), int(pub.Mode)),
+               pad(pub.X.Bytes(), int(pub.Mode))...,
+       )
        reverse(raw)
        return raw
 }
 
-func (pk *PublicKey) VerifyDigest(digest, signature []byte) (bool, error) {
-       if len(digest) != pk.ds {
-               return false, errors.New("Invalid input digest length")
-       }
-       if len(signature) != 2*pk.ds {
+func (pub *PublicKey) VerifyDigest(digest, signature []byte) (bool, error) {
+       if len(signature) != 2*int(pub.Mode) {
                return false, errors.New("Invalid signature length")
        }
-       s := bytes2big(signature[:pk.ds])
-       r := bytes2big(signature[pk.ds:])
-       if r.Cmp(zero) <= 0 || r.Cmp(pk.c.Q) >= 0 || s.Cmp(zero) <= 0 || s.Cmp(pk.c.Q) >= 0 {
+       s := bytes2big(signature[:pub.Mode])
+       r := bytes2big(signature[pub.Mode:])
+       if r.Cmp(zero) <= 0 || r.Cmp(pub.C.Q) >= 0 || s.Cmp(zero) <= 0 || s.Cmp(pub.C.Q) >= 0 {
                return false, nil
        }
        e := bytes2big(digest)
-       e.Mod(e, pk.c.Q)
+       e.Mod(e, pub.C.Q)
        if e.Cmp(zero) == 0 {
                e = big.NewInt(1)
        }
        v := big.NewInt(0)
-       v.ModInverse(e, pk.c.Q)
+       v.ModInverse(e, pub.C.Q)
        z1 := big.NewInt(0)
        z2 := big.NewInt(0)
        z1.Mul(s, v)
-       z1.Mod(z1, pk.c.Q)
+       z1.Mod(z1, pub.C.Q)
        z2.Mul(r, v)
-       z2.Mod(z2, pk.c.Q)
-       z2.Sub(pk.c.Q, z2)
-       p1x, p1y, err := pk.c.Exp(z1, pk.c.Bx, pk.c.By)
+       z2.Mod(z2, pub.C.Q)
+       z2.Sub(pub.C.Q, z2)
+       p1x, p1y, err := pub.C.Exp(z1, pub.C.X, pub.C.Y)
        if err != nil {
                return false, err
        }
-       q1x, q1y, err := pk.c.Exp(z2, pk.x, pk.y)
+       q1x, q1y, err := pub.C.Exp(z2, pub.X, pub.Y)
        if err != nil {
                return false, err
        }
        lm := big.NewInt(0)
        lm.Sub(q1x, p1x)
        if lm.Cmp(zero) < 0 {
-               lm.Add(lm, pk.c.P)
+               lm.Add(lm, pub.C.P)
        }
-       lm.ModInverse(lm, pk.c.P)
+       lm.ModInverse(lm, pub.C.P)
        z1.Sub(q1y, p1y)
        lm.Mul(lm, z1)
-       lm.Mod(lm, pk.c.P)
+       lm.Mod(lm, pub.C.P)
        lm.Mul(lm, lm)
-       lm.Mod(lm, pk.c.P)
+       lm.Mod(lm, pub.C.P)
        lm.Sub(lm, p1x)
        lm.Sub(lm, q1x)
-       lm.Mod(lm, pk.c.P)
+       lm.Mod(lm, pub.C.P)
        if lm.Cmp(zero) < 0 {
-               lm.Add(lm, pk.c.P)
+               lm.Add(lm, pub.C.P)
        }
-       lm.Mod(lm, pk.c.Q)
+       lm.Mod(lm, pub.C.Q)
        return lm.Cmp(r) == 0, nil
 }