Source code for pvss.ristretto_255

"""
Binding to `libsodium <https://libsodium.org/>`_ to use `Ristretto255 <https://ristretto.group/>`_ for group operations.
"""

from __future__ import annotations

import ctypes
import ctypes.util
import hmac
from dataclasses import dataclass
from fractions import Fraction
from os import environ
from secrets import randbelow
from typing import TYPE_CHECKING, ByteString, Optional, Union

from asn1crypto.core import Asn1Value, OctetString

from . import asn1 as _asn1
from .groups import ImageGroup, ImageValue, PgvOrInt, PreGroup, PreGroupValue
from .pvss import Pvss, SystemParameters

if TYPE_CHECKING:  # pragma: no cover
    lazy = property
else:
    from lazy import lazy


# Order of the Ristretto255 group.
_RST_255_GROUP_ORDER = 2 ** 252 + 27742317777372353535851937790883648493


def create_ristretto_255_parameters(pvss: Pvss) -> bytes:
    """
    Create and set Ristretto255 parameters.

    Builds a ``Ristretto255Parameters`` object for ``pvss``, takes its DER
    encoding and registers that encoding on ``pvss`` via ``set_params``.

    Args:
        pvss: Pvss object with public values

    Returns:
        DER encoded Ristretto255 system parameters.
    """
    params_der = Ristretto255Parameters.create(pvss, None).der
    pvss.set_params(params_der)
    return params_der
class Ristretto255Parameters(SystemParameters):
    """
    System parameters that use Ristretto255 points as the image group and
    Ristretto255 scalars as the pre-image group.
    """

    ALGO = "ristretto_255"

    @lazy
    def pre_group(self) -> Ristretto255ScalarGroup:
        """
        Scalar group used for exponents (created once, then cached by ``lazy``).
        """
        return Ristretto255ScalarGroup()

    @lazy
    def img_group(self) -> Ristretto255Group:
        """
        Point group, bound to :attr:`pre_group` (created once, then cached by ``lazy``).
        """
        return Ristretto255Group(self.pre_group)

    def _make_gen(self, seed: str) -> Ristretto255Point:
        """
        Derive a generator from ``seed`` and the DER encoding of these parameters.

        The seed is used as HMAC-SHA512 key over ``self.der``; the 64 byte
        digest is mapped to a point with ``from_hash``.

        Args:
            seed: Label selecting which generator to derive.

        Returns:
            A point that is not the neutral element.
        """
        while True:
            dig = hmac.digest(seed.encode(), self.der, "sha512")
            gen = self.img_group.from_hash(dig)

            # Check for neutral elements; we don't want those.
            if gen:
                return gen
            else:  # pragma: no cover
                # Try again with other seed.
                seed += "_"


class _Lib:
    """
    Namespace holding the loaded libsodium library and typed ctypes bindings
    for the Ristretto255 functions used in this module.

    The bindings are created while the class body is executed, i.e. at import
    time. The C prototype of every function is given in the comment above its
    binding.
    """

    try:
        lib_name = ctypes.util.find_library("sodium")
        if not lib_name:  # pragma: no cover
            raise Exception("libsodium not found")
        lib = ctypes.cdll.LoadLibrary(lib_name)

        if lib.sodium_init() < 0:  # pragma: no cover
            raise Exception("Cannot initialize libsodium")

        # int sodium_memcmp(const void * const b1_, const void * const b2_, size_t len);
        memcmp = lib.sodium_memcmp
        memcmp.restype = ctypes.c_int
        memcmp.argtypes = ctypes.c_char_p, ctypes.c_char_p, ctypes.c_size_t

        # int sodium_is_zero(const unsigned char *n, const size_t nlen);
        is_zero = lib.sodium_is_zero
        is_zero.restype = ctypes.c_int
        is_zero.argtypes = ctypes.c_char_p, ctypes.c_size_t

        # int crypto_core_ristretto255_is_valid_point(const unsigned char *p);
        point_is_valid = lib.crypto_core_ristretto255_is_valid_point
        point_is_valid.restype = ctypes.c_int
        point_is_valid.argtypes = (ctypes.c_char_p,)

        # void crypto_core_ristretto255_random(unsigned char *p);
        point_random = lib.crypto_core_ristretto255_random
        point_random.restype = None
        point_random.argtypes = (ctypes.c_char_p,)

        # int crypto_core_ristretto255_from_hash(unsigned char *p, const unsigned char *r);
        point_from_hash = lib.crypto_core_ristretto255_from_hash
        point_from_hash.restype = ctypes.c_int
        point_from_hash.argtypes = ctypes.c_char_p, ctypes.c_char_p

        # int crypto_scalarmult_ristretto255(unsigned char *q, const unsigned char *n, const unsigned char *p);
        point_mul = lib.crypto_scalarmult_ristretto255
        point_mul.restype = ctypes.c_int
        point_mul.argtypes = ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p

        # int crypto_scalarmult_ristretto255_base(unsigned char *q, const unsigned char *n);
        point_base_mul = lib.crypto_scalarmult_ristretto255_base
        point_base_mul.restype = ctypes.c_int
        point_base_mul.argtypes = ctypes.c_char_p, ctypes.c_char_p

        # int crypto_core_ristretto255_add(unsigned char *r, const unsigned char *p, const unsigned char *q);
        point_add = lib.crypto_core_ristretto255_add
        point_add.restype = ctypes.c_int
        point_add.argtypes = ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p

        # int crypto_core_ristretto255_sub(unsigned char *r, const unsigned char *p, const unsigned char *q);
        point_sub = lib.crypto_core_ristretto255_sub
        point_sub.restype = ctypes.c_int
        point_sub.argtypes = ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p

        # void crypto_core_ristretto255_scalar_random(unsigned char *r);
        scalar_random = lib.crypto_core_ristretto255_scalar_random
        scalar_random.restype = None
        scalar_random.argtypes = (ctypes.c_char_p,)

        # void crypto_core_ristretto255_scalar_reduce(unsigned char *r, const unsigned char *s);
        scalar_reduce = lib.crypto_core_ristretto255_scalar_reduce
        scalar_reduce.restype = None
        scalar_reduce.argtypes = ctypes.c_char_p, ctypes.c_char_p

        # int crypto_core_ristretto255_scalar_invert(unsigned char *recip, const unsigned char *s);
        scalar_invert = lib.crypto_core_ristretto255_scalar_invert
        scalar_invert.restype = ctypes.c_int
        scalar_invert.argtypes = ctypes.c_char_p, ctypes.c_char_p

        # void crypto_core_ristretto255_scalar_negate(unsigned char *neg, const unsigned char *s);
        scalar_negate = lib.crypto_core_ristretto255_scalar_negate
        scalar_negate.restype = None
        scalar_negate.argtypes = ctypes.c_char_p, ctypes.c_char_p

        # void crypto_core_ristretto255_scalar_complement(unsigned char *comp, const unsigned char *s);
        scalar_complement = lib.crypto_core_ristretto255_scalar_complement
        scalar_complement.restype = None
        scalar_complement.argtypes = ctypes.c_char_p, ctypes.c_char_p

        # void crypto_core_ristretto255_scalar_add(unsigned char *z, const unsigned char *x, const unsigned char *y);
        scalar_add = lib.crypto_core_ristretto255_scalar_add
        scalar_add.restype = None
        scalar_add.argtypes = ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p

        # void crypto_core_ristretto255_scalar_sub(unsigned char *z, const unsigned char *x, const unsigned char *y);
        scalar_sub = lib.crypto_core_ristretto255_scalar_sub
        scalar_sub.restype = None
        scalar_sub.argtypes = ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p

        # void crypto_core_ristretto255_scalar_mul(unsigned char *z, const unsigned char *x, const unsigned char *y);
        scalar_mul = lib.crypto_core_ristretto255_scalar_mul
        scalar_mul.restype = None
        scalar_mul.argtypes = ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p

    except Exception:  # pragma: no cover
        # Work around the fact that libsodium is not installed in the readthedocs build image
        if "READTHEDOCS" not in environ:
            raise


@dataclass(frozen=True, eq=False, repr=False)
class Ristretto255Group(ImageGroup):
    """
    The Ristretto255 group; its elements are :class:`Ristretto255Point` objects.
    """

    # Scalar group whose elements serve as exponents for points of this group.
    pre_group: Ristretto255ScalarGroup

    def __call__(self, value: Union[Asn1Value]) -> Ristretto255Point:
        """
        Create from serialized buffer, inverse of bytes()

        Args:
            value: ``ImgGroupValue`` whose chosen alternative is an
                ``OctetString`` holding the 32 byte point encoding.

        Returns:
            The decoded point.

        Raises:
            TypeError: If ``value`` or its chosen alternative has the wrong type.
            ValueError: If the encoding is not exactly 32 bytes long or is
                rejected by libsodium as an invalid point.
        """

        if not isinstance(value, _asn1.ImgGroupValue):
            raise TypeError(type(value))

        buf = value.chosen
        if not isinstance(buf, OctetString):
            raise TypeError(type(buf))

        raw = bytes(buf)
        # create_string_buffer() would silently zero-pad a short input, so a
        # truncated (non-canonical) serialization could pass the validity
        # check below. Insist on the exact encoding length instead.
        if len(raw) != 32:
            raise ValueError("Not a valid point")

        res = ctypes.create_string_buffer(raw, 32)
        if not _Lib.point_is_valid(res):
            raise ValueError("Not a valid point")

        return Ristretto255Point(self, res)

    def random(self) -> Ristretto255Point:
        """
        Generate random element.

        Returns:
            Point filled in by libsodium's ``crypto_core_ristretto255_random``.
        """

        res = ctypes.create_string_buffer(32)
        _Lib.point_random(res)
        return Ristretto255Point(self, res)

    def from_hash(self, value: ByteString) -> Ristretto255Point:
        """
        Generate a point from up to 64 bytes. Those would usually come out of a hash function.

        Shorter inputs are zero-padded to 64 bytes by ``create_string_buffer``.

        Args:
            value: Input bytes, at most 64.

        Returns:
            The point derived from ``value``.
        """

        buf = ctypes.create_string_buffer(bytes(value), 64)
        res = ctypes.create_string_buffer(32)
        if _Lib.point_from_hash(res, buf) < 0:
            raise Exception("Unknown error")  # pragma: no cover

        return Ristretto255Point(self, res)

    @property
    def len(self) -> int:
        """
        Get number of elements in this group

        Returns:
            group size
        """
        return _RST_255_GROUP_ORDER

    def __repr__(self) -> str:
        """
        Outputs a representation of this group.
        """
        return "Ristretto255Group()"


@dataclass(frozen=True, eq=False, repr=False)
class Ristretto255Point(ImageValue):
    """
    Element of :class:`Ristretto255Group`, written multiplicatively.
    """

    # Group this point belongs to.
    group: Ristretto255Group
    # 32 byte buffer with the encoded point, as passed to libsodium.
    _buf: ctypes.Array[ctypes.c_char]

    @lazy
    def asn1(self) -> _asn1.ImgGroupValue:
        """
        ASN.1 representation: the 32 encoded bytes under the ``ECPoint`` alternative.
        """
        return _asn1.ImgGroupValue({"ECPoint": OctetString(bytes(self))})

    def __pow__(
        self, other: Union[PgvOrInt, Fraction], modulo: Optional[int] = None
    ) -> Ristretto255Point:
        """
        Compute self ** other

        Args:
            other: Exponent; ints and Fractions are first converted to scalars.
            modulo: Must be None.

        Raises:
            TypeError: If ``modulo`` is given.
            ValueError: If libsodium's scalar multiplication reports failure.
        """

        if modulo is not None:
            raise TypeError("modulo must be None")

        if isinstance(other, (int, Fraction)):
            buf = self.group.pre_group(other)._buf
        elif isinstance(other, Ristretto255Scalar):
            buf = other._buf
        else:
            return NotImplemented

        res = ctypes.create_string_buffer(32)
        if _Lib.point_mul(res, buf, self._buf) < 0:
            raise ValueError("Zero")

        return Ristretto255Point(self.group, res)

    def __mul__(self, other: ImageValue) -> Ristretto255Point:
        """
        Compute self * other

        Raises:
            ValueError: If libsodium's point addition reports failure.
        """

        if not isinstance(other, Ristretto255Point):
            return NotImplemented

        res = ctypes.create_string_buffer(32)
        if _Lib.point_add(res, self._buf, other._buf) < 0:
            raise ValueError("Encoding error")

        return Ristretto255Point(self.group, res)

    def __floordiv__(self, other: ImageValue) -> Ristretto255Point:
        """
        Compute self / other

        Raises:
            ValueError: If libsodium's point subtraction reports failure.
        """

        if not isinstance(other, Ristretto255Point):
            return NotImplemented

        res = ctypes.create_string_buffer(32)
        if _Lib.point_sub(res, self._buf, other._buf) < 0:
            raise ValueError("Encoding error")

        return Ristretto255Point(self.group, res)

    def __bytes__(self) -> bytes:
        """
        Return the 32 byte encoding of this point.
        """
        return bytes(self._buf)

    def __str__(self) -> str:
        """
        Return a readable form showing the encoding as hex digits.
        """
        return f"Ristretto255Point(0x{bytes(self).hex()})"

    def __repr__(self) -> str:
        """
        Return a representation showing the encoded bytes.
        """
        return f"Ristretto255Point.from_bytes({bytes(self)!r})"

    def __eq__(self, other: object) -> bool:
        """
        Compare the encodings with ``sodium_memcmp``; other types are never equal.
        """
        if not isinstance(other, Ristretto255Point):
            return False
        return not _Lib.memcmp(self._buf, other._buf, 32)

    def __bool__(self) -> bool:
        """
        False if this is the neutral element
        """
        return not _Lib.is_zero(self._buf, 32)

    def __hash__(self) -> int:
        """
        Hash of the encoded bytes, consistent with :meth:`__eq__`.
        """
        return hash(bytes(self._buf))


class Ristretto255ScalarGroup(PreGroup):
    """
    Group of scalars modulo the Ristretto255 group order.
    """

    def __call__(self, value: Union[int, Asn1Value, Fraction]) -> Ristretto255Scalar:
        """
        Convert an integer into a group element

        Args:
            value: One of

                * ``PreGroupValue``: must already lie in ``[0, group order)``.
                * ``int``: any integer, reduced modulo the group order.
                * ``Fraction``: numerator times the inverse of the denominator.

        Returns:
            Group element

        Raises:
            ValueError: If a ``PreGroupValue`` is out of range, or a Fraction's
                denominator cannot be inverted.
            TypeError: If ``value`` has an unsupported type.
        """

        if isinstance(value, _asn1.PreGroupValue):
            value = int(value)
            if not 0 <= value < _RST_255_GROUP_ORDER:
                raise ValueError("Not a valid group element")
            res = ctypes.create_string_buffer(value.to_bytes(32, "little"), 32)
            return Ristretto255Scalar(self, res)

        if isinstance(value, int):
            # scalar_reduce takes an unsigned input: reduce the magnitude and
            # negate the resulting scalar afterwards.
            neg = value < 0
            if neg:
                value = -value

            # scalar_reduce reads exactly 64 bytes. Larger magnitudes (which
            # would make to_bytes() raise OverflowError) are reduced here first.
            if value.bit_length() > 512:
                value %= _RST_255_GROUP_ORDER

            val = ctypes.create_string_buffer(value.to_bytes(64, "little"), 64)
            res = ctypes.create_string_buffer(32)
            _Lib.scalar_reduce(res, val)

            scalar = Ristretto255Scalar(self, res)
            return -scalar if neg else scalar

        if isinstance(value, Fraction):
            return self(value.numerator) * self(value.denominator).inv

        raise TypeError(type(value))

    @property
    def len(self) -> int:
        """
        Get number of elements in this group

        Returns:
            group size
        """
        return _RST_255_GROUP_ORDER

    @property
    def rand(self) -> Ristretto255Scalar:
        """
        Create random element of this group

        Returns:
            Random group element
        """

        value = randbelow(_RST_255_GROUP_ORDER)
        res = ctypes.create_string_buffer(value.to_bytes(32, "little"), 32)
        return Ristretto255Scalar(self, res)

    @property
    def rand_nonzero(self) -> Ristretto255Scalar:
        """
        Create random element of this group, but never the neutral element.

        Returns:
            Random group element
        """

        res = ctypes.create_string_buffer(32)
        _Lib.scalar_random(res)
        return Ristretto255Scalar(self, res)

    def __repr__(self) -> str:
        """
        Outputs a representation of this group.

        Returns:
            Representation of this group
        """
        return "Ristretto255ScalarGroup()"


@dataclass(frozen=True, eq=False, repr=False)
class Ristretto255Scalar(PreGroupValue):
    """
    Element of :class:`Ristretto255ScalarGroup`.
    """

    # Group this scalar belongs to.
    group: Ristretto255ScalarGroup
    # 32 byte buffer with the scalar in little-endian byte order.
    _buf: ctypes.Array[ctypes.c_char]

    def __neg__(self) -> Ristretto255Scalar:
        """
        Compute -self
        """
        res = ctypes.create_string_buffer(32)
        _Lib.scalar_negate(res, self._buf)
        return Ristretto255Scalar(self.group, res)

    def __add__(self, other: Union[int, PreGroupValue]) -> Ristretto255Scalar:
        """
        Compute self + other; ints are converted to scalars first.
        """
        if isinstance(other, int):
            buf = self.group(other)._buf
        elif isinstance(other, Ristretto255Scalar):
            buf = other._buf
        else:
            return NotImplemented

        res = ctypes.create_string_buffer(32)
        _Lib.scalar_add(res, self._buf, buf)
        return Ristretto255Scalar(self.group, res)

    def __sub__(self, other: Union[int, PreGroupValue]) -> Ristretto255Scalar:
        """
        Compute self - other; ints are converted to scalars first.
        """
        if isinstance(other, int):
            buf = self.group(other)._buf
        elif isinstance(other, Ristretto255Scalar):
            buf = other._buf
        else:
            return NotImplemented

        res = ctypes.create_string_buffer(32)
        _Lib.scalar_sub(res, self._buf, buf)
        return Ristretto255Scalar(self.group, res)

    def __mul__(self, other: Union[int, PreGroupValue]) -> Ristretto255Scalar:
        """
        Compute self * other; ints are converted to scalars first.
        """
        if isinstance(other, int):
            buf = self.group(other)._buf
        elif isinstance(other, Ristretto255Scalar):
            buf = other._buf
        else:
            return NotImplemented

        res = ctypes.create_string_buffer(32)
        _Lib.scalar_mul(res, self._buf, buf)
        return Ristretto255Scalar(self.group, res)

    @lazy
    def inv(self) -> Ristretto255Scalar:
        """
        Multiplicative inverse of this scalar (computed once, then cached by ``lazy``).

        Raises:
            ValueError: If libsodium reports that the value cannot be inverted.
        """
        res = ctypes.create_string_buffer(32)
        if _Lib.scalar_invert(res, self._buf) < 0:
            raise ValueError("Cannot invert value")
        return Ristretto255Scalar(self.group, res)

    def __eq__(self, other: object) -> bool:
        """
        Compare the byte representations with ``sodium_memcmp``; other types are never equal.
        """
        if not isinstance(other, Ristretto255Scalar):
            return False
        return not _Lib.memcmp(self._buf, other._buf, 32)

    def __bool__(self) -> bool:
        """
        False if all 32 bytes are zero.
        """
        return not _Lib.is_zero(self._buf, 32)

    def __repr__(self) -> str:
        """
        Return a representation showing the integer value.
        """
        return f"Ristretto255ScalarGroup()({int(self)})"

    def __int__(self) -> int:
        """
        Interpret the buffer as a little-endian integer.
        """
        return int.from_bytes(bytes(self._buf), "little")

    def __bytes__(self) -> bytes:
        """
        Return the 32 byte little-endian representation.
        """
        return bytes(self._buf)

    @lazy
    def asn1(self) -> _asn1.PreGroupValue:
        """
        ASN.1 representation built from the integer value.
        """
        return _asn1.PreGroupValue(int(self))

    def __hash__(self) -> int:
        """
        Hash of the byte representation, consistent with :meth:`__eq__`.
        """
        return hash(bytes(self))