]> Cypherpunks.ru repositories - pygost.git/blobdiff - pygost/gost3410.py
Remove excess mode kwargs from gost3410* functions
[pygost.git] / pygost / gost3410.py
index 9f0a11e656689745056b8f7c53cdf5c3779201b3..b518c95707b48c9d0f3c0741a25a2845c90b2ac0 100644 (file)
@@ -28,10 +28,10 @@ from pygost.utils import long2bytes
 from pygost.utils import modinvert
 
 
-MODE2SIZE = {
-    2001: 32,
-    2012: 64,
-}
+def point_size(point):
+    """Determine is it either 256 or 512 bit point
+    """
+    return (512 // 8) if point.bit_length() > 256 else (256 // 8)
 
 
 class GOST3410Curve(object):
@@ -70,6 +70,10 @@ class GOST3410Curve(object):
             raise ValueError("Invalid parameters")
         self._st = None
 
+    @property
+    def point_size(self):
+        return point_size(self.p)
+
     def pos(self, v):
         """Make positive number
         """
@@ -225,7 +229,7 @@ def public_key(curve, prv):
     return curve.exp(prv)
 
 
-def sign(curve, prv, digest, rand=None, mode=2001):
+def sign(curve, prv, digest, rand=None):
     """ Calculate signature for provided digest
 
     :param GOST3410Curve curve: curve to use
@@ -237,7 +241,7 @@ def sign(curve, prv, digest, rand=None, mode=2001):
     :returns: signature, BE(S) || BE(R)
     :rtype: bytes, 64 or 128 bytes
     """
-    size = MODE2SIZE[mode]
+    size = curve.point_size
     q = curve.q
     e = bytes2long(digest) % q
     if e == 0:
@@ -263,7 +267,7 @@ def sign(curve, prv, digest, rand=None, mode=2001):
     return long2bytes(s, size) + long2bytes(r, size)
 
 
-def verify(curve, pub, digest, signature, mode=2001):
+def verify(curve, pub, digest, signature):
     """ Verify provided digest with the signature
 
     :param GOST3410Curve curve: curve to use
@@ -274,7 +278,7 @@ def verify(curve, pub, digest, signature, mode=2001):
     :type signature: bytes, 64 or 128 bytes
     :rtype: bool
     """
-    size = MODE2SIZE[mode]
+    size = curve.point_size
     if len(signature) != size * 2:
         raise ValueError("Invalid signature length")
     q = curve.q
@@ -316,25 +320,25 @@ def prv_unmarshal(prv):
     return bytes2long(prv[::-1])
 
 
-def pub_marshal(pub, mode=2001):
+def pub_marshal(pub):
     """Marshal public key
 
     :type pub: (long, long)
     :rtype: bytes
     :returns: LE(X) || LE(Y)
     """
-    size = MODE2SIZE[mode]
+    size = point_size(pub[0])
     return (long2bytes(pub[1], size) + long2bytes(pub[0], size))[::-1]
 
 
-def pub_unmarshal(pub, mode=2001):
+def pub_unmarshal(pub):
     """Unmarshal public key
 
     :param pub: LE(X) || LE(Y)
     :type pub: bytes
     :rtype: (long, long)
     """
-    size = MODE2SIZE[mode]
+    size = len(pub) // 2
     pub = pub[::-1]
     return (bytes2long(pub[size:]), bytes2long(pub[:size]))