refactor
This commit is contained in:
5
.gitignore
vendored
5
.gitignore
vendored
@@ -0,0 +1,5 @@
|
|||||||
|
driver
|
||||||
|
venv
|
||||||
|
*.log
|
||||||
|
html
|
||||||
|
imgs
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,32 +0,0 @@
|
|||||||
# Copyright (C) AB Strakt
|
|
||||||
# See LICENSE for details.
|
|
||||||
|
|
||||||
"""
|
|
||||||
pyOpenSSL - A simple wrapper around the OpenSSL library
|
|
||||||
"""
|
|
||||||
|
|
||||||
from OpenSSL import crypto, SSL
|
|
||||||
from OpenSSL.version import (
|
|
||||||
__author__,
|
|
||||||
__copyright__,
|
|
||||||
__email__,
|
|
||||||
__license__,
|
|
||||||
__summary__,
|
|
||||||
__title__,
|
|
||||||
__uri__,
|
|
||||||
__version__,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"SSL",
|
|
||||||
"crypto",
|
|
||||||
"__author__",
|
|
||||||
"__copyright__",
|
|
||||||
"__email__",
|
|
||||||
"__license__",
|
|
||||||
"__summary__",
|
|
||||||
"__title__",
|
|
||||||
"__uri__",
|
|
||||||
"__version__",
|
|
||||||
]
|
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
from cryptography.hazmat.bindings.openssl.binding import Binding
|
|
||||||
|
|
||||||
|
|
||||||
binding = Binding()
|
|
||||||
ffi = binding.ffi
|
|
||||||
lib = binding.lib
|
|
||||||
|
|
||||||
|
|
||||||
# This is a special CFFI allocator that does not bother to zero its memory
|
|
||||||
# after allocation. This has vastly better performance on large allocations and
|
|
||||||
# so should be used whenever we don't need the memory zeroed out.
|
|
||||||
no_zero_allocator = ffi.new_allocator(should_clear_after_alloc=False)
|
|
||||||
|
|
||||||
|
|
||||||
def text(charp):
|
|
||||||
"""
|
|
||||||
Get a native string type representing of the given CFFI ``char*`` object.
|
|
||||||
|
|
||||||
:param charp: A C-style string represented using CFFI.
|
|
||||||
|
|
||||||
:return: :class:`str`
|
|
||||||
"""
|
|
||||||
if not charp:
|
|
||||||
return ""
|
|
||||||
return ffi.string(charp).decode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def exception_from_error_queue(exception_type):
|
|
||||||
"""
|
|
||||||
Convert an OpenSSL library failure into a Python exception.
|
|
||||||
|
|
||||||
When a call to the native OpenSSL library fails, this is usually signalled
|
|
||||||
by the return value, and an error code is stored in an error queue
|
|
||||||
associated with the current thread. The err library provides functions to
|
|
||||||
obtain these error codes and textual error messages.
|
|
||||||
"""
|
|
||||||
errors = []
|
|
||||||
|
|
||||||
while True:
|
|
||||||
error = lib.ERR_get_error()
|
|
||||||
if error == 0:
|
|
||||||
break
|
|
||||||
errors.append(
|
|
||||||
(
|
|
||||||
text(lib.ERR_lib_error_string(error)),
|
|
||||||
text(lib.ERR_func_error_string(error)),
|
|
||||||
text(lib.ERR_reason_error_string(error)),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
raise exception_type(errors)
|
|
||||||
|
|
||||||
|
|
||||||
def make_assert(error):
|
|
||||||
"""
|
|
||||||
Create an assert function that uses :func:`exception_from_error_queue` to
|
|
||||||
raise an exception wrapped by *error*.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def openssl_assert(ok):
|
|
||||||
"""
|
|
||||||
If *ok* is not True, retrieve the error from OpenSSL and raise it.
|
|
||||||
"""
|
|
||||||
if ok is not True:
|
|
||||||
exception_from_error_queue(error)
|
|
||||||
|
|
||||||
return openssl_assert
|
|
||||||
|
|
||||||
|
|
||||||
def path_bytes(s):
|
|
||||||
"""
|
|
||||||
Convert a Python path to a :py:class:`bytes` for the path which can be
|
|
||||||
passed into an OpenSSL API accepting a filename.
|
|
||||||
|
|
||||||
:param s: A path (valid for os.fspath).
|
|
||||||
|
|
||||||
:return: An instance of :py:class:`bytes`.
|
|
||||||
"""
|
|
||||||
b = os.fspath(s)
|
|
||||||
|
|
||||||
if isinstance(b, str):
|
|
||||||
return b.encode(sys.getfilesystemencoding())
|
|
||||||
else:
|
|
||||||
return b
|
|
||||||
|
|
||||||
|
|
||||||
def byte_string(s):
|
|
||||||
return s.encode("charmap")
|
|
||||||
|
|
||||||
|
|
||||||
# A marker object to observe whether some optional arguments are passed any
|
|
||||||
# value or not.
|
|
||||||
UNSPECIFIED = object()
|
|
||||||
|
|
||||||
_TEXT_WARNING = "str for {0} is no longer accepted, use bytes"
|
|
||||||
|
|
||||||
|
|
||||||
def text_to_bytes_and_warn(label, obj):
|
|
||||||
"""
|
|
||||||
If ``obj`` is text, emit a warning that it should be bytes instead and try
|
|
||||||
to convert it to bytes automatically.
|
|
||||||
|
|
||||||
:param str label: The name of the parameter from which ``obj`` was taken
|
|
||||||
(so a developer can easily find the source of the problem and correct
|
|
||||||
it).
|
|
||||||
|
|
||||||
:return: If ``obj`` is the text string type, a ``bytes`` object giving the
|
|
||||||
UTF-8 encoding of that text is returned. Otherwise, ``obj`` itself is
|
|
||||||
returned.
|
|
||||||
"""
|
|
||||||
if isinstance(obj, str):
|
|
||||||
warnings.warn(
|
|
||||||
_TEXT_WARNING.format(label),
|
|
||||||
category=DeprecationWarning,
|
|
||||||
stacklevel=3,
|
|
||||||
)
|
|
||||||
return obj.encode("utf-8")
|
|
||||||
return obj
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,42 +0,0 @@
|
|||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import ssl
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import OpenSSL.SSL
|
|
||||||
import cffi
|
|
||||||
import cryptography
|
|
||||||
|
|
||||||
from . import version
|
|
||||||
|
|
||||||
|
|
||||||
_env_info = """\
|
|
||||||
pyOpenSSL: {pyopenssl}
|
|
||||||
cryptography: {cryptography}
|
|
||||||
cffi: {cffi}
|
|
||||||
cryptography's compiled against OpenSSL: {crypto_openssl_compile}
|
|
||||||
cryptography's linked OpenSSL: {crypto_openssl_link}
|
|
||||||
Python's OpenSSL: {python_openssl}
|
|
||||||
Python executable: {python}
|
|
||||||
Python version: {python_version}
|
|
||||||
Platform: {platform}
|
|
||||||
sys.path: {sys_path}""".format(
|
|
||||||
pyopenssl=version.__version__,
|
|
||||||
crypto_openssl_compile=OpenSSL._util.ffi.string(
|
|
||||||
OpenSSL._util.lib.OPENSSL_VERSION_TEXT,
|
|
||||||
).decode("ascii"),
|
|
||||||
crypto_openssl_link=OpenSSL.SSL.SSLeay_version(
|
|
||||||
OpenSSL.SSL.SSLEAY_VERSION
|
|
||||||
).decode("ascii"),
|
|
||||||
python_openssl=getattr(ssl, "OPENSSL_VERSION", "n/a"),
|
|
||||||
cryptography=cryptography.__version__,
|
|
||||||
cffi=cffi.__version__,
|
|
||||||
python=sys.executable,
|
|
||||||
python_version=sys.version,
|
|
||||||
platform=sys.platform,
|
|
||||||
sys_path=sys.path,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print(_env_info)
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
"""
|
|
||||||
PRNG management routines, thin wrappers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from OpenSSL._util import lib as _lib
|
|
||||||
|
|
||||||
|
|
||||||
def add(buffer, entropy):
|
|
||||||
"""
|
|
||||||
Mix bytes from *string* into the PRNG state.
|
|
||||||
|
|
||||||
The *entropy* argument is (the lower bound of) an estimate of how much
|
|
||||||
randomness is contained in *string*, measured in bytes.
|
|
||||||
|
|
||||||
For more information, see e.g. :rfc:`1750`.
|
|
||||||
|
|
||||||
This function is only relevant if you are forking Python processes and
|
|
||||||
need to reseed the CSPRNG after fork.
|
|
||||||
|
|
||||||
:param buffer: Buffer with random data.
|
|
||||||
:param entropy: The entropy (in bytes) measurement of the buffer.
|
|
||||||
|
|
||||||
:return: :obj:`None`
|
|
||||||
"""
|
|
||||||
if not isinstance(buffer, bytes):
|
|
||||||
raise TypeError("buffer must be a byte string")
|
|
||||||
|
|
||||||
if not isinstance(entropy, int):
|
|
||||||
raise TypeError("entropy must be an integer")
|
|
||||||
|
|
||||||
_lib.RAND_add(buffer, len(buffer), entropy)
|
|
||||||
|
|
||||||
|
|
||||||
def status():
|
|
||||||
"""
|
|
||||||
Check whether the PRNG has been seeded with enough data.
|
|
||||||
|
|
||||||
:return: 1 if the PRNG is seeded enough, 0 otherwise.
|
|
||||||
"""
|
|
||||||
return _lib.RAND_status()
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
# Copyright (C) AB Strakt
|
|
||||||
# Copyright (C) Jean-Paul Calderone
|
|
||||||
# See LICENSE for details.
|
|
||||||
|
|
||||||
"""
|
|
||||||
pyOpenSSL - A simple wrapper around the OpenSSL library
|
|
||||||
"""
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"__author__",
|
|
||||||
"__copyright__",
|
|
||||||
"__email__",
|
|
||||||
"__license__",
|
|
||||||
"__summary__",
|
|
||||||
"__title__",
|
|
||||||
"__uri__",
|
|
||||||
"__version__",
|
|
||||||
]
|
|
||||||
|
|
||||||
__version__ = "22.0.0"
|
|
||||||
|
|
||||||
__title__ = "pyOpenSSL"
|
|
||||||
__uri__ = "https://pyopenssl.org/"
|
|
||||||
__summary__ = "Python wrapper module around the OpenSSL library"
|
|
||||||
__author__ = "The pyOpenSSL developers"
|
|
||||||
__email__ = "cryptography-dev@python.org"
|
|
||||||
__license__ = "Apache License, Version 2.0"
|
|
||||||
__copyright__ = "Copyright 2001-2020 {0}".format(__author__)
|
|
||||||
@@ -1,128 +0,0 @@
|
|||||||
import sys
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import importlib
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
|
|
||||||
is_pypy = '__pypy__' in sys.builtin_module_names
|
|
||||||
|
|
||||||
|
|
||||||
warnings.filterwarnings('ignore',
|
|
||||||
r'.+ distutils\b.+ deprecated',
|
|
||||||
DeprecationWarning)
|
|
||||||
|
|
||||||
|
|
||||||
def warn_distutils_present():
|
|
||||||
if 'distutils' not in sys.modules:
|
|
||||||
return
|
|
||||||
if is_pypy and sys.version_info < (3, 7):
|
|
||||||
# PyPy for 3.6 unconditionally imports distutils, so bypass the warning
|
|
||||||
# https://foss.heptapod.net/pypy/pypy/-/blob/be829135bc0d758997b3566062999ee8b23872b4/lib-python/3/site.py#L250
|
|
||||||
return
|
|
||||||
warnings.warn(
|
|
||||||
"Distutils was imported before Setuptools, but importing Setuptools "
|
|
||||||
"also replaces the `distutils` module in `sys.modules`. This may lead "
|
|
||||||
"to undesirable behaviors or errors. To avoid these issues, avoid "
|
|
||||||
"using distutils directly, ensure that setuptools is installed in the "
|
|
||||||
"traditional way (e.g. not an editable install), and/or make sure "
|
|
||||||
"that setuptools is always imported before distutils.")
|
|
||||||
|
|
||||||
|
|
||||||
def clear_distutils():
|
|
||||||
if 'distutils' not in sys.modules:
|
|
||||||
return
|
|
||||||
warnings.warn("Setuptools is replacing distutils.")
|
|
||||||
mods = [name for name in sys.modules if re.match(r'distutils\b', name)]
|
|
||||||
for name in mods:
|
|
||||||
del sys.modules[name]
|
|
||||||
|
|
||||||
|
|
||||||
def enabled():
|
|
||||||
"""
|
|
||||||
Allow selection of distutils by environment variable.
|
|
||||||
"""
|
|
||||||
which = os.environ.get('SETUPTOOLS_USE_DISTUTILS', 'stdlib')
|
|
||||||
return which == 'local'
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_local_distutils():
|
|
||||||
clear_distutils()
|
|
||||||
distutils = importlib.import_module('setuptools._distutils')
|
|
||||||
distutils.__name__ = 'distutils'
|
|
||||||
sys.modules['distutils'] = distutils
|
|
||||||
|
|
||||||
# sanity check that submodules load as expected
|
|
||||||
core = importlib.import_module('distutils.core')
|
|
||||||
assert '_distutils' in core.__file__, core.__file__
|
|
||||||
|
|
||||||
|
|
||||||
def do_override():
|
|
||||||
"""
|
|
||||||
Ensure that the local copy of distutils is preferred over stdlib.
|
|
||||||
|
|
||||||
See https://github.com/pypa/setuptools/issues/417#issuecomment-392298401
|
|
||||||
for more motivation.
|
|
||||||
"""
|
|
||||||
if enabled():
|
|
||||||
warn_distutils_present()
|
|
||||||
ensure_local_distutils()
|
|
||||||
|
|
||||||
|
|
||||||
class DistutilsMetaFinder:
|
|
||||||
def find_spec(self, fullname, path, target=None):
|
|
||||||
if path is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
method_name = 'spec_for_{fullname}'.format(**locals())
|
|
||||||
method = getattr(self, method_name, lambda: None)
|
|
||||||
return method()
|
|
||||||
|
|
||||||
def spec_for_distutils(self):
|
|
||||||
import importlib.abc
|
|
||||||
import importlib.util
|
|
||||||
|
|
||||||
class DistutilsLoader(importlib.abc.Loader):
|
|
||||||
|
|
||||||
def create_module(self, spec):
|
|
||||||
return importlib.import_module('setuptools._distutils')
|
|
||||||
|
|
||||||
def exec_module(self, module):
|
|
||||||
pass
|
|
||||||
|
|
||||||
return importlib.util.spec_from_loader('distutils', DistutilsLoader())
|
|
||||||
|
|
||||||
def spec_for_pip(self):
|
|
||||||
"""
|
|
||||||
Ensure stdlib distutils when running under pip.
|
|
||||||
See pypa/pip#8761 for rationale.
|
|
||||||
"""
|
|
||||||
if self.pip_imported_during_build():
|
|
||||||
return
|
|
||||||
clear_distutils()
|
|
||||||
self.spec_for_distutils = lambda: None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def pip_imported_during_build():
|
|
||||||
"""
|
|
||||||
Detect if pip is being imported in a build script. Ref #2355.
|
|
||||||
"""
|
|
||||||
import traceback
|
|
||||||
return any(
|
|
||||||
frame.f_globals['__file__'].endswith('setup.py')
|
|
||||||
for frame, line in traceback.walk_stack(None)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
DISTUTILS_FINDER = DistutilsMetaFinder()
|
|
||||||
|
|
||||||
|
|
||||||
def add_shim():
|
|
||||||
sys.meta_path.insert(0, DISTUTILS_FINDER)
|
|
||||||
|
|
||||||
|
|
||||||
def remove_shim():
|
|
||||||
try:
|
|
||||||
sys.meta_path.remove(DISTUTILS_FINDER)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
__import__('_distutils_hack').do_override()
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
from ._version import __version__
|
|
||||||
from ._impl import (
|
|
||||||
async_generator,
|
|
||||||
yield_,
|
|
||||||
yield_from_,
|
|
||||||
isasyncgen,
|
|
||||||
isasyncgenfunction,
|
|
||||||
get_asyncgen_hooks,
|
|
||||||
set_asyncgen_hooks,
|
|
||||||
)
|
|
||||||
from ._util import aclosing, asynccontextmanager
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"async_generator",
|
|
||||||
"yield_",
|
|
||||||
"yield_from_",
|
|
||||||
"aclosing",
|
|
||||||
"isasyncgen",
|
|
||||||
"isasyncgenfunction",
|
|
||||||
"asynccontextmanager",
|
|
||||||
"get_asyncgen_hooks",
|
|
||||||
"set_asyncgen_hooks",
|
|
||||||
]
|
|
||||||
@@ -1,455 +0,0 @@
|
|||||||
import sys
|
|
||||||
from functools import wraps
|
|
||||||
from types import coroutine
|
|
||||||
import inspect
|
|
||||||
from inspect import (
|
|
||||||
getcoroutinestate, CORO_CREATED, CORO_CLOSED, CORO_SUSPENDED
|
|
||||||
)
|
|
||||||
import collections.abc
|
|
||||||
|
|
||||||
|
|
||||||
class YieldWrapper:
|
|
||||||
def __init__(self, payload):
|
|
||||||
self.payload = payload
|
|
||||||
|
|
||||||
|
|
||||||
def _wrap(value):
|
|
||||||
return YieldWrapper(value)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_wrapped(box):
|
|
||||||
return isinstance(box, YieldWrapper)
|
|
||||||
|
|
||||||
|
|
||||||
def _unwrap(box):
|
|
||||||
return box.payload
|
|
||||||
|
|
||||||
|
|
||||||
# This is the magic code that lets you use yield_ and yield_from_ with native
|
|
||||||
# generators.
|
|
||||||
#
|
|
||||||
# The old version worked great on Linux and MacOS, but not on Windows, because
|
|
||||||
# it depended on _PyAsyncGenValueWrapperNew. The new version segfaults
|
|
||||||
# everywhere, and I'm not sure why -- probably my lack of understanding
|
|
||||||
# of ctypes and refcounts.
|
|
||||||
#
|
|
||||||
# There are also some commented out tests that should be re-enabled if this is
|
|
||||||
# fixed:
|
|
||||||
#
|
|
||||||
# if sys.version_info >= (3, 6):
|
|
||||||
# # Use the same box type that the interpreter uses internally. This allows
|
|
||||||
# # yield_ and (more importantly!) yield_from_ to work in built-in
|
|
||||||
# # generators.
|
|
||||||
# import ctypes # mua ha ha.
|
|
||||||
#
|
|
||||||
# # We used to call _PyAsyncGenValueWrapperNew to create and set up new
|
|
||||||
# # wrapper objects, but that symbol isn't available on Windows:
|
|
||||||
# #
|
|
||||||
# # https://github.com/python-trio/async_generator/issues/5
|
|
||||||
# #
|
|
||||||
# # Fortunately, the type object is available, but it means we have to do
|
|
||||||
# # this the hard way.
|
|
||||||
#
|
|
||||||
# # We don't actually need to access this, but we need to make a ctypes
|
|
||||||
# # structure so we can call addressof.
|
|
||||||
# class _ctypes_PyTypeObject(ctypes.Structure):
|
|
||||||
# pass
|
|
||||||
# _PyAsyncGenWrappedValue_Type_ptr = ctypes.addressof(
|
|
||||||
# _ctypes_PyTypeObject.in_dll(
|
|
||||||
# ctypes.pythonapi, "_PyAsyncGenWrappedValue_Type"))
|
|
||||||
# _PyObject_GC_New = ctypes.pythonapi._PyObject_GC_New
|
|
||||||
# _PyObject_GC_New.restype = ctypes.py_object
|
|
||||||
# _PyObject_GC_New.argtypes = (ctypes.c_void_p,)
|
|
||||||
#
|
|
||||||
# _Py_IncRef = ctypes.pythonapi.Py_IncRef
|
|
||||||
# _Py_IncRef.restype = None
|
|
||||||
# _Py_IncRef.argtypes = (ctypes.py_object,)
|
|
||||||
#
|
|
||||||
# class _ctypes_PyAsyncGenWrappedValue(ctypes.Structure):
|
|
||||||
# _fields_ = [
|
|
||||||
# ('PyObject_HEAD', ctypes.c_byte * object().__sizeof__()),
|
|
||||||
# ('agw_val', ctypes.py_object),
|
|
||||||
# ]
|
|
||||||
# def _wrap(value):
|
|
||||||
# box = _PyObject_GC_New(_PyAsyncGenWrappedValue_Type_ptr)
|
|
||||||
# raw = ctypes.cast(ctypes.c_void_p(id(box)),
|
|
||||||
# ctypes.POINTER(_ctypes_PyAsyncGenWrappedValue))
|
|
||||||
# raw.contents.agw_val = value
|
|
||||||
# _Py_IncRef(value)
|
|
||||||
# return box
|
|
||||||
#
|
|
||||||
# def _unwrap(box):
|
|
||||||
# assert _is_wrapped(box)
|
|
||||||
# raw = ctypes.cast(ctypes.c_void_p(id(box)),
|
|
||||||
# ctypes.POINTER(_ctypes_PyAsyncGenWrappedValue))
|
|
||||||
# value = raw.contents.agw_val
|
|
||||||
# _Py_IncRef(value)
|
|
||||||
# return value
|
|
||||||
#
|
|
||||||
# _PyAsyncGenWrappedValue_Type = type(_wrap(1))
|
|
||||||
# def _is_wrapped(box):
|
|
||||||
# return isinstance(box, _PyAsyncGenWrappedValue_Type)
|
|
||||||
|
|
||||||
|
|
||||||
# The magic @coroutine decorator is how you write the bottom level of
|
|
||||||
# coroutine stacks -- 'async def' can only use 'await' = yield from; but
|
|
||||||
# eventually we must bottom out in a @coroutine that calls plain 'yield'.
|
|
||||||
@coroutine
|
|
||||||
def _yield_(value):
|
|
||||||
return (yield _wrap(value))
|
|
||||||
|
|
||||||
|
|
||||||
# But we wrap the bare @coroutine version in an async def, because async def
|
|
||||||
# has the magic feature that users can get warnings messages if they forget to
|
|
||||||
# use 'await'.
|
|
||||||
async def yield_(value=None):
|
|
||||||
return await _yield_(value)
|
|
||||||
|
|
||||||
|
|
||||||
async def yield_from_(delegate):
|
|
||||||
# Transcribed with adaptations from:
|
|
||||||
#
|
|
||||||
# https://www.python.org/dev/peps/pep-0380/#formal-semantics
|
|
||||||
#
|
|
||||||
# This takes advantage of a sneaky trick: if an @async_generator-wrapped
|
|
||||||
# function calls another async function (like yield_from_), and that
|
|
||||||
# second async function calls yield_, then because of the hack we use to
|
|
||||||
# implement yield_, the yield_ will actually propagate through yield_from_
|
|
||||||
# back to the @async_generator wrapper. So even though we're a regular
|
|
||||||
# function, we can directly yield values out of the calling async
|
|
||||||
# generator.
|
|
||||||
def unpack_StopAsyncIteration(e):
|
|
||||||
if e.args:
|
|
||||||
return e.args[0]
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
_i = type(delegate).__aiter__(delegate)
|
|
||||||
if hasattr(_i, "__await__"):
|
|
||||||
_i = await _i
|
|
||||||
try:
|
|
||||||
_y = await type(_i).__anext__(_i)
|
|
||||||
except StopAsyncIteration as _e:
|
|
||||||
_r = unpack_StopAsyncIteration(_e)
|
|
||||||
else:
|
|
||||||
while 1:
|
|
||||||
try:
|
|
||||||
_s = await yield_(_y)
|
|
||||||
except GeneratorExit as _e:
|
|
||||||
try:
|
|
||||||
_m = _i.aclose
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
await _m()
|
|
||||||
raise _e
|
|
||||||
except BaseException as _e:
|
|
||||||
_x = sys.exc_info()
|
|
||||||
try:
|
|
||||||
_m = _i.athrow
|
|
||||||
except AttributeError:
|
|
||||||
raise _e
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
_y = await _m(*_x)
|
|
||||||
except StopAsyncIteration as _e:
|
|
||||||
_r = unpack_StopAsyncIteration(_e)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
if _s is None:
|
|
||||||
_y = await type(_i).__anext__(_i)
|
|
||||||
else:
|
|
||||||
_y = await _i.asend(_s)
|
|
||||||
except StopAsyncIteration as _e:
|
|
||||||
_r = unpack_StopAsyncIteration(_e)
|
|
||||||
break
|
|
||||||
return _r
|
|
||||||
|
|
||||||
|
|
||||||
# This is the awaitable / iterator that implements asynciter.__anext__() and
|
|
||||||
# friends.
|
|
||||||
#
|
|
||||||
# Note: we can be sloppy about the distinction between
|
|
||||||
#
|
|
||||||
# type(self._it).__next__(self._it)
|
|
||||||
#
|
|
||||||
# and
|
|
||||||
#
|
|
||||||
# self._it.__next__()
|
|
||||||
#
|
|
||||||
# because we happen to know that self._it is not a general iterator object,
|
|
||||||
# but specifically a coroutine iterator object where these are equivalent.
|
|
||||||
class ANextIter:
|
|
||||||
def __init__(self, it, first_fn, *first_args):
|
|
||||||
self._it = it
|
|
||||||
self._first_fn = first_fn
|
|
||||||
self._first_args = first_args
|
|
||||||
|
|
||||||
def __await__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __next__(self):
|
|
||||||
if self._first_fn is not None:
|
|
||||||
first_fn = self._first_fn
|
|
||||||
first_args = self._first_args
|
|
||||||
self._first_fn = self._first_args = None
|
|
||||||
return self._invoke(first_fn, *first_args)
|
|
||||||
else:
|
|
||||||
return self._invoke(self._it.__next__)
|
|
||||||
|
|
||||||
def send(self, value):
|
|
||||||
return self._invoke(self._it.send, value)
|
|
||||||
|
|
||||||
def throw(self, type, value=None, traceback=None):
|
|
||||||
return self._invoke(self._it.throw, type, value, traceback)
|
|
||||||
|
|
||||||
def _invoke(self, fn, *args):
|
|
||||||
try:
|
|
||||||
result = fn(*args)
|
|
||||||
except StopIteration as e:
|
|
||||||
# The underlying generator returned, so we should signal the end
|
|
||||||
# of iteration.
|
|
||||||
raise StopAsyncIteration(e.value)
|
|
||||||
except StopAsyncIteration as e:
|
|
||||||
# PEP 479 says: if a generator raises Stop(Async)Iteration, then
|
|
||||||
# it should be wrapped into a RuntimeError. Python automatically
|
|
||||||
# enforces this for StopIteration; for StopAsyncIteration we need
|
|
||||||
# to it ourselves.
|
|
||||||
raise RuntimeError(
|
|
||||||
"async_generator raise StopAsyncIteration"
|
|
||||||
) from e
|
|
||||||
if _is_wrapped(result):
|
|
||||||
raise StopIteration(_unwrap(result))
|
|
||||||
else:
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
UNSPECIFIED = object()
|
|
||||||
try:
|
|
||||||
from sys import get_asyncgen_hooks, set_asyncgen_hooks
|
|
||||||
|
|
||||||
except ImportError:
|
|
||||||
import threading
|
|
||||||
|
|
||||||
asyncgen_hooks = collections.namedtuple(
|
|
||||||
"asyncgen_hooks", ("firstiter", "finalizer")
|
|
||||||
)
|
|
||||||
|
|
||||||
class _hooks_storage(threading.local):
|
|
||||||
def __init__(self):
|
|
||||||
self.firstiter = None
|
|
||||||
self.finalizer = None
|
|
||||||
|
|
||||||
_hooks = _hooks_storage()
|
|
||||||
|
|
||||||
def get_asyncgen_hooks():
|
|
||||||
return asyncgen_hooks(
|
|
||||||
firstiter=_hooks.firstiter, finalizer=_hooks.finalizer
|
|
||||||
)
|
|
||||||
|
|
||||||
def set_asyncgen_hooks(firstiter=UNSPECIFIED, finalizer=UNSPECIFIED):
|
|
||||||
if firstiter is not UNSPECIFIED:
|
|
||||||
if firstiter is None or callable(firstiter):
|
|
||||||
_hooks.firstiter = firstiter
|
|
||||||
else:
|
|
||||||
raise TypeError(
|
|
||||||
"callable firstiter expected, got {}".format(
|
|
||||||
type(firstiter).__name__
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if finalizer is not UNSPECIFIED:
|
|
||||||
if finalizer is None or callable(finalizer):
|
|
||||||
_hooks.finalizer = finalizer
|
|
||||||
else:
|
|
||||||
raise TypeError(
|
|
||||||
"callable finalizer expected, got {}".format(
|
|
||||||
type(finalizer).__name__
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncGenerator:
|
|
||||||
# https://bitbucket.org/pypy/pypy/issues/2786:
|
|
||||||
# PyPy implements 'await' in a way that requires the frame object
|
|
||||||
# used to execute a coroutine to keep a weakref to that coroutine.
|
|
||||||
# During a GC pass, weakrefs to all doomed objects are broken
|
|
||||||
# before any of the doomed objects' finalizers are invoked.
|
|
||||||
# If an AsyncGenerator is unreachable, its _coroutine probably
|
|
||||||
# is too, and the weakref from ag._coroutine.cr_frame to
|
|
||||||
# ag._coroutine will be broken before ag.__del__ can do its
|
|
||||||
# one-turn close attempt or can schedule a full aclose() using
|
|
||||||
# the registered finalization hook. It doesn't look like the
|
|
||||||
# underlying issue is likely to be fully fixed anytime soon,
|
|
||||||
# so we work around it by preventing an AsyncGenerator and
|
|
||||||
# its _coroutine from being considered newly unreachable at
|
|
||||||
# the same time if the AsyncGenerator's finalizer might want
|
|
||||||
# to iterate the coroutine some more.
|
|
||||||
_pypy_issue2786_workaround = set()
|
|
||||||
|
|
||||||
def __init__(self, coroutine):
|
|
||||||
self._coroutine = coroutine
|
|
||||||
self._it = coroutine.__await__()
|
|
||||||
self.ag_running = False
|
|
||||||
self._finalizer = None
|
|
||||||
self._closed = False
|
|
||||||
self._hooks_inited = False
|
|
||||||
|
|
||||||
# On python 3.5.0 and 3.5.1, __aiter__ must be awaitable.
|
|
||||||
# Starting in 3.5.2, it should not be awaitable, and if it is, then it
|
|
||||||
# raises a PendingDeprecationWarning.
|
|
||||||
# See:
|
|
||||||
# https://www.python.org/dev/peps/pep-0492/#api-design-and-implementation-revisions
|
|
||||||
# https://docs.python.org/3/reference/datamodel.html#async-iterators
|
|
||||||
# https://bugs.python.org/issue27243
|
|
||||||
if sys.version_info < (3, 5, 2):
|
|
||||||
|
|
||||||
async def __aiter__(self):
|
|
||||||
return self
|
|
||||||
else:
|
|
||||||
|
|
||||||
def __aiter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
################################################################
|
|
||||||
# Introspection attributes
|
|
||||||
################################################################
|
|
||||||
|
|
||||||
@property
|
|
||||||
def ag_code(self):
|
|
||||||
return self._coroutine.cr_code
|
|
||||||
|
|
||||||
@property
|
|
||||||
def ag_frame(self):
|
|
||||||
return self._coroutine.cr_frame
|
|
||||||
|
|
||||||
################################################################
|
|
||||||
# Core functionality
|
|
||||||
################################################################
|
|
||||||
|
|
||||||
# These need to return awaitables, rather than being async functions,
|
|
||||||
# to match the native behavior where the firstiter hook is called
|
|
||||||
# immediately on asend()/etc, even if the coroutine that asend()
|
|
||||||
# produces isn't awaited for a bit.
|
|
||||||
|
|
||||||
def __anext__(self):
|
|
||||||
return self._do_it(self._it.__next__)
|
|
||||||
|
|
||||||
def asend(self, value):
|
|
||||||
return self._do_it(self._it.send, value)
|
|
||||||
|
|
||||||
def athrow(self, type, value=None, traceback=None):
|
|
||||||
return self._do_it(self._it.throw, type, value, traceback)
|
|
||||||
|
|
||||||
def _do_it(self, start_fn, *args):
|
|
||||||
if not self._hooks_inited:
|
|
||||||
self._hooks_inited = True
|
|
||||||
(firstiter, self._finalizer) = get_asyncgen_hooks()
|
|
||||||
if firstiter is not None:
|
|
||||||
firstiter(self)
|
|
||||||
if sys.implementation.name == "pypy":
|
|
||||||
self._pypy_issue2786_workaround.add(self._coroutine)
|
|
||||||
|
|
||||||
# On CPython 3.5.2 (but not 3.5.0), coroutines get cranky if you try
|
|
||||||
# to iterate them after they're exhausted. Generators OTOH just raise
|
|
||||||
# StopIteration. We want to convert the one into the other, so we need
|
|
||||||
# to avoid iterating stopped coroutines.
|
|
||||||
if getcoroutinestate(self._coroutine) is CORO_CLOSED:
|
|
||||||
raise StopAsyncIteration()
|
|
||||||
|
|
||||||
async def step():
|
|
||||||
if self.ag_running:
|
|
||||||
raise ValueError("async generator already executing")
|
|
||||||
try:
|
|
||||||
self.ag_running = True
|
|
||||||
return await ANextIter(self._it, start_fn, *args)
|
|
||||||
except StopAsyncIteration:
|
|
||||||
self._pypy_issue2786_workaround.discard(self._coroutine)
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self.ag_running = False
|
|
||||||
|
|
||||||
return step()
|
|
||||||
|
|
||||||
################################################################
|
|
||||||
# Cleanup
|
|
||||||
################################################################
|
|
||||||
|
|
||||||
async def aclose(self):
|
|
||||||
state = getcoroutinestate(self._coroutine)
|
|
||||||
if state is CORO_CLOSED or self._closed:
|
|
||||||
return
|
|
||||||
# Make sure that even if we raise "async_generator ignored
|
|
||||||
# GeneratorExit", and thus fail to exhaust the coroutine,
|
|
||||||
# __del__ doesn't complain again.
|
|
||||||
self._closed = True
|
|
||||||
if state is CORO_CREATED:
|
|
||||||
# Make sure that aclose() on an unstarted generator returns
|
|
||||||
# successfully and prevents future iteration.
|
|
||||||
self._it.close()
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
await self.athrow(GeneratorExit)
|
|
||||||
except (GeneratorExit, StopAsyncIteration):
|
|
||||||
self._pypy_issue2786_workaround.discard(self._coroutine)
|
|
||||||
else:
|
|
||||||
raise RuntimeError("async_generator ignored GeneratorExit")
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
self._pypy_issue2786_workaround.discard(self._coroutine)
|
|
||||||
if getcoroutinestate(self._coroutine) is CORO_CREATED:
|
|
||||||
# Never started, nothing to clean up, just suppress the "coroutine
|
|
||||||
# never awaited" message.
|
|
||||||
self._coroutine.close()
|
|
||||||
if getcoroutinestate(self._coroutine
|
|
||||||
) is CORO_SUSPENDED and not self._closed:
|
|
||||||
if self._finalizer is not None:
|
|
||||||
self._finalizer(self)
|
|
||||||
else:
|
|
||||||
# Mimic the behavior of native generators on GC with no finalizer:
|
|
||||||
# throw in GeneratorExit, run for one turn, and complain if it didn't
|
|
||||||
# finish.
|
|
||||||
thrower = self.athrow(GeneratorExit)
|
|
||||||
try:
|
|
||||||
thrower.send(None)
|
|
||||||
except (GeneratorExit, StopAsyncIteration):
|
|
||||||
pass
|
|
||||||
except StopIteration:
|
|
||||||
raise RuntimeError("async_generator ignored GeneratorExit")
|
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
|
||||||
"async_generator {!r} awaited during finalization; install "
|
|
||||||
"a finalization hook to support this, or wrap it in "
|
|
||||||
"'async with aclosing(...):'"
|
|
||||||
.format(self.ag_code.co_name)
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
thrower.close()
|
|
||||||
|
|
||||||
|
|
||||||
if hasattr(collections.abc, "AsyncGenerator"):
|
|
||||||
collections.abc.AsyncGenerator.register(AsyncGenerator)
|
|
||||||
|
|
||||||
|
|
||||||
def async_generator(coroutine_maker):
|
|
||||||
@wraps(coroutine_maker)
|
|
||||||
def async_generator_maker(*args, **kwargs):
|
|
||||||
return AsyncGenerator(coroutine_maker(*args, **kwargs))
|
|
||||||
|
|
||||||
async_generator_maker._async_gen_function = id(async_generator_maker)
|
|
||||||
return async_generator_maker
|
|
||||||
|
|
||||||
|
|
||||||
def isasyncgen(obj):
|
|
||||||
if hasattr(inspect, "isasyncgen"):
|
|
||||||
if inspect.isasyncgen(obj):
|
|
||||||
return True
|
|
||||||
return isinstance(obj, AsyncGenerator)
|
|
||||||
|
|
||||||
|
|
||||||
def isasyncgenfunction(obj):
|
|
||||||
if hasattr(inspect, "isasyncgenfunction"):
|
|
||||||
if inspect.isasyncgenfunction(obj):
|
|
||||||
return True
|
|
||||||
return getattr(obj, "_async_gen_function", -1) == id(obj)
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
import pytest
|
|
||||||
from functools import wraps, partial
|
|
||||||
import inspect
|
|
||||||
import types
|
|
||||||
|
|
||||||
|
|
||||||
@types.coroutine
|
|
||||||
def mock_sleep():
|
|
||||||
yield "mock_sleep"
|
|
||||||
|
|
||||||
|
|
||||||
# Wrap any 'async def' tests so that they get automatically iterated.
|
|
||||||
# We used to use pytest-asyncio as a convenient way to do this, but nowadays
|
|
||||||
# pytest-asyncio uses us! In addition to it being generally bad for our test
|
|
||||||
# infrastructure to depend on the code-under-test, this totally messes up
|
|
||||||
# coverage info because depending on pytest's plugin load order, we might get
|
|
||||||
# imported before pytest-cov can be loaded and start gathering coverage.
|
|
||||||
@pytest.hookimpl(tryfirst=True)
|
|
||||||
def pytest_pyfunc_call(pyfuncitem):
|
|
||||||
if inspect.iscoroutinefunction(pyfuncitem.obj):
|
|
||||||
fn = pyfuncitem.obj
|
|
||||||
|
|
||||||
@wraps(fn)
|
|
||||||
def wrapper(**kwargs):
|
|
||||||
coro = fn(**kwargs)
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
value = coro.send(None)
|
|
||||||
if value != "mock_sleep": # pragma: no cover
|
|
||||||
raise RuntimeError(
|
|
||||||
"coroutine runner confused: {!r}".format(value)
|
|
||||||
)
|
|
||||||
except StopIteration:
|
|
||||||
pass
|
|
||||||
|
|
||||||
pyfuncitem.obj = wrapper
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,227 +0,0 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from .. import aclosing, async_generator, yield_, asynccontextmanager
|
|
||||||
|
|
||||||
|
|
||||||
@async_generator
|
|
||||||
async def async_range(count, closed_slot):
|
|
||||||
try:
|
|
||||||
for i in range(count): # pragma: no branch
|
|
||||||
await yield_(i)
|
|
||||||
except GeneratorExit:
|
|
||||||
closed_slot[0] = True
|
|
||||||
|
|
||||||
|
|
||||||
async def test_aclosing():
|
|
||||||
closed_slot = [False]
|
|
||||||
async with aclosing(async_range(10, closed_slot)) as gen:
|
|
||||||
it = iter(range(10))
|
|
||||||
async for item in gen: # pragma: no branch
|
|
||||||
assert item == next(it)
|
|
||||||
if item == 4:
|
|
||||||
break
|
|
||||||
assert closed_slot[0]
|
|
||||||
|
|
||||||
closed_slot = [False]
|
|
||||||
try:
|
|
||||||
async with aclosing(async_range(10, closed_slot)) as gen:
|
|
||||||
it = iter(range(10))
|
|
||||||
async for item in gen: # pragma: no branch
|
|
||||||
assert item == next(it)
|
|
||||||
if item == 4:
|
|
||||||
raise ValueError()
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
assert closed_slot[0]
|
|
||||||
|
|
||||||
|
|
||||||
async def test_contextmanager_do_not_unchain_non_stopiteration_exceptions():
|
|
||||||
@asynccontextmanager
|
|
||||||
@async_generator
|
|
||||||
async def manager_issue29692():
|
|
||||||
try:
|
|
||||||
await yield_()
|
|
||||||
except Exception as exc:
|
|
||||||
raise RuntimeError('issue29692:Chained') from exc
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError) as excinfo:
|
|
||||||
async with manager_issue29692():
|
|
||||||
raise ZeroDivisionError
|
|
||||||
assert excinfo.value.args[0] == 'issue29692:Chained'
|
|
||||||
assert isinstance(excinfo.value.__cause__, ZeroDivisionError)
|
|
||||||
|
|
||||||
# This is a little funky because of implementation details in
|
|
||||||
# async_generator It can all go away once we stop supporting Python3.5
|
|
||||||
with pytest.raises(RuntimeError) as excinfo:
|
|
||||||
async with manager_issue29692():
|
|
||||||
exc = StopIteration('issue29692:Unchained')
|
|
||||||
raise exc
|
|
||||||
assert excinfo.value.args[0] == 'issue29692:Chained'
|
|
||||||
cause = excinfo.value.__cause__
|
|
||||||
assert cause.args[0] == 'generator raised StopIteration'
|
|
||||||
assert cause.__cause__ is exc
|
|
||||||
|
|
||||||
with pytest.raises(StopAsyncIteration) as excinfo:
|
|
||||||
async with manager_issue29692():
|
|
||||||
raise StopAsyncIteration('issue29692:Unchained')
|
|
||||||
assert excinfo.value.args[0] == 'issue29692:Unchained'
|
|
||||||
assert excinfo.value.__cause__ is None
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
@async_generator
|
|
||||||
async def noop_async_context_manager():
|
|
||||||
await yield_()
|
|
||||||
|
|
||||||
with pytest.raises(StopIteration):
|
|
||||||
async with noop_async_context_manager():
|
|
||||||
raise StopIteration
|
|
||||||
|
|
||||||
|
|
||||||
# Native async generators are only available from Python 3.6 and onwards
|
|
||||||
nativeasyncgenerators = True
|
|
||||||
try:
|
|
||||||
exec(
|
|
||||||
"""
|
|
||||||
@asynccontextmanager
|
|
||||||
async def manager_issue29692_2():
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
except Exception as exc:
|
|
||||||
raise RuntimeError('issue29692:Chained') from exc
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
except SyntaxError:
|
|
||||||
nativeasyncgenerators = False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
not nativeasyncgenerators,
|
|
||||||
reason="Python < 3.6 doesn't have native async generators"
|
|
||||||
)
|
|
||||||
async def test_native_contextmanager_do_not_unchain_non_stopiteration_exceptions(
|
|
||||||
):
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError) as excinfo:
|
|
||||||
async with manager_issue29692_2():
|
|
||||||
raise ZeroDivisionError
|
|
||||||
assert excinfo.value.args[0] == 'issue29692:Chained'
|
|
||||||
assert isinstance(excinfo.value.__cause__, ZeroDivisionError)
|
|
||||||
|
|
||||||
for cls in [StopIteration, StopAsyncIteration]:
|
|
||||||
with pytest.raises(cls) as excinfo:
|
|
||||||
async with manager_issue29692_2():
|
|
||||||
raise cls('issue29692:Unchained')
|
|
||||||
assert excinfo.value.args[0] == 'issue29692:Unchained'
|
|
||||||
assert excinfo.value.__cause__ is None
|
|
||||||
|
|
||||||
|
|
||||||
async def test_asynccontextmanager_exception_passthrough():
|
|
||||||
# This was the cause of annoying coverage flapping, see gh-140
|
|
||||||
@asynccontextmanager
|
|
||||||
@async_generator
|
|
||||||
async def noop_async_context_manager():
|
|
||||||
await yield_()
|
|
||||||
|
|
||||||
for exc_type in [StopAsyncIteration, RuntimeError, ValueError]:
|
|
||||||
with pytest.raises(exc_type):
|
|
||||||
async with noop_async_context_manager():
|
|
||||||
raise exc_type
|
|
||||||
|
|
||||||
# And let's also check a boring nothing pass-through while we're at it
|
|
||||||
async with noop_async_context_manager():
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def test_asynccontextmanager_catches_exception():
|
|
||||||
@asynccontextmanager
|
|
||||||
@async_generator
|
|
||||||
async def catch_it():
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
await yield_()
|
|
||||||
|
|
||||||
async with catch_it():
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
|
|
||||||
async def test_asynccontextmanager_different_exception():
|
|
||||||
@asynccontextmanager
|
|
||||||
@async_generator
|
|
||||||
async def switch_it():
|
|
||||||
try:
|
|
||||||
await yield_()
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
async with switch_it():
|
|
||||||
raise KeyError
|
|
||||||
|
|
||||||
|
|
||||||
async def test_asynccontextmanager_nice_message_on_sync_enter():
|
|
||||||
@asynccontextmanager
|
|
||||||
@async_generator
|
|
||||||
async def xxx(): # pragma: no cover
|
|
||||||
await yield_()
|
|
||||||
|
|
||||||
cm = xxx()
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError) as excinfo:
|
|
||||||
with cm:
|
|
||||||
pass # pragma: no cover
|
|
||||||
|
|
||||||
assert "async with" in str(excinfo.value)
|
|
||||||
|
|
||||||
async with cm:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def test_asynccontextmanager_no_yield():
|
|
||||||
@asynccontextmanager
|
|
||||||
@async_generator
|
|
||||||
async def yeehaw():
|
|
||||||
pass
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError) as excinfo:
|
|
||||||
async with yeehaw():
|
|
||||||
assert False # pragma: no cover
|
|
||||||
|
|
||||||
assert "didn't yield" in str(excinfo.value)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_asynccontextmanager_too_many_yields():
|
|
||||||
closed_count = 0
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
@async_generator
|
|
||||||
async def doubleyield():
|
|
||||||
try:
|
|
||||||
await yield_()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
await yield_()
|
|
||||||
finally:
|
|
||||||
nonlocal closed_count
|
|
||||||
closed_count += 1
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError) as excinfo:
|
|
||||||
async with doubleyield():
|
|
||||||
pass
|
|
||||||
|
|
||||||
assert "didn't stop" in str(excinfo.value)
|
|
||||||
assert closed_count == 1
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError) as excinfo:
|
|
||||||
async with doubleyield():
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
assert "didn't stop after athrow" in str(excinfo.value)
|
|
||||||
assert closed_count == 2
|
|
||||||
|
|
||||||
|
|
||||||
async def test_asynccontextmanager_requires_asyncgenfunction():
|
|
||||||
with pytest.raises(TypeError):
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
def syncgen(): # pragma: no cover
|
|
||||||
yield
|
|
||||||
@@ -1,110 +0,0 @@
|
|||||||
import sys
|
|
||||||
from functools import wraps
|
|
||||||
from ._impl import isasyncgenfunction
|
|
||||||
|
|
||||||
|
|
||||||
class aclosing:
|
|
||||||
def __init__(self, aiter):
|
|
||||||
self._aiter = aiter
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
|
||||||
return self._aiter
|
|
||||||
|
|
||||||
async def __aexit__(self, *args):
|
|
||||||
await self._aiter.aclose()
|
|
||||||
|
|
||||||
|
|
||||||
# Very much derived from the one in contextlib, by copy/pasting and then
|
|
||||||
# asyncifying everything. (Also I dropped the obscure support for using
|
|
||||||
# context managers as function decorators. It could be re-added; I just
|
|
||||||
# couldn't be bothered.)
|
|
||||||
# So this is a derivative work licensed under the PSF License, which requires
|
|
||||||
# the following notice:
|
|
||||||
#
|
|
||||||
# Copyright © 2001-2017 Python Software Foundation; All Rights Reserved
|
|
||||||
class _AsyncGeneratorContextManager:
|
|
||||||
def __init__(self, func, args, kwds):
|
|
||||||
self._func_name = func.__name__
|
|
||||||
self._agen = func(*args, **kwds).__aiter__()
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
|
||||||
if sys.version_info < (3, 5, 2):
|
|
||||||
self._agen = await self._agen
|
|
||||||
try:
|
|
||||||
return await self._agen.asend(None)
|
|
||||||
except StopAsyncIteration:
|
|
||||||
raise RuntimeError("async generator didn't yield") from None
|
|
||||||
|
|
||||||
async def __aexit__(self, type, value, traceback):
|
|
||||||
async with aclosing(self._agen):
|
|
||||||
if type is None:
|
|
||||||
try:
|
|
||||||
await self._agen.asend(None)
|
|
||||||
except StopAsyncIteration:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
raise RuntimeError("async generator didn't stop")
|
|
||||||
else:
|
|
||||||
# It used to be possible to have type != None, value == None:
|
|
||||||
# https://bugs.python.org/issue1705170
|
|
||||||
# but AFAICT this can't happen anymore.
|
|
||||||
assert value is not None
|
|
||||||
try:
|
|
||||||
await self._agen.athrow(type, value, traceback)
|
|
||||||
raise RuntimeError(
|
|
||||||
"async generator didn't stop after athrow()"
|
|
||||||
)
|
|
||||||
except StopAsyncIteration as exc:
|
|
||||||
# Suppress StopIteration *unless* it's the same exception
|
|
||||||
# that was passed to throw(). This prevents a
|
|
||||||
# StopIteration raised inside the "with" statement from
|
|
||||||
# being suppressed.
|
|
||||||
return (exc is not value)
|
|
||||||
except RuntimeError as exc:
|
|
||||||
# Don't re-raise the passed in exception. (issue27112)
|
|
||||||
if exc is value:
|
|
||||||
return False
|
|
||||||
# Likewise, avoid suppressing if a StopIteration exception
|
|
||||||
# was passed to throw() and later wrapped into a
|
|
||||||
# RuntimeError (see PEP 479).
|
|
||||||
if (isinstance(value, (StopIteration, StopAsyncIteration))
|
|
||||||
and exc.__cause__ is value):
|
|
||||||
return False
|
|
||||||
raise
|
|
||||||
except:
|
|
||||||
# only re-raise if it's *not* the exception that was
|
|
||||||
# passed to throw(), because __exit__() must not raise an
|
|
||||||
# exception unless __exit__() itself failed. But throw()
|
|
||||||
# has to raise the exception to signal propagation, so
|
|
||||||
# this fixes the impedance mismatch between the throw()
|
|
||||||
# protocol and the __exit__() protocol.
|
|
||||||
#
|
|
||||||
if sys.exc_info()[1] is value:
|
|
||||||
return False
|
|
||||||
raise
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
raise RuntimeError(
|
|
||||||
"use 'async with {func_name}(...)', not 'with {func_name}(...)'".
|
|
||||||
format(func_name=self._func_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __exit__(self): # pragma: no cover
|
|
||||||
assert False, """Never called, but should be defined"""
|
|
||||||
|
|
||||||
|
|
||||||
def asynccontextmanager(func):
|
|
||||||
"""Like @contextmanager, but async."""
|
|
||||||
if not isasyncgenfunction(func):
|
|
||||||
raise TypeError(
|
|
||||||
"must be an async generator (native or from async_generator; "
|
|
||||||
"if using @async_generator then @acontextmanager must be on top."
|
|
||||||
)
|
|
||||||
|
|
||||||
@wraps(func)
|
|
||||||
def helper(*args, **kwds):
|
|
||||||
return _AsyncGeneratorContextManager(func, args, kwds)
|
|
||||||
|
|
||||||
# A hint for sphinxcontrib-trio:
|
|
||||||
helper.__returns_acontextmanager__ = True
|
|
||||||
return helper
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
__version__ = "1.10"
|
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from . import converters, exceptions, filters, setters, validators
|
|
||||||
from ._cmp import cmp_using
|
|
||||||
from ._config import get_run_validators, set_run_validators
|
|
||||||
from ._funcs import asdict, assoc, astuple, evolve, has, resolve_types
|
|
||||||
from ._make import (
|
|
||||||
NOTHING,
|
|
||||||
Attribute,
|
|
||||||
Factory,
|
|
||||||
attrib,
|
|
||||||
attrs,
|
|
||||||
fields,
|
|
||||||
fields_dict,
|
|
||||||
make_class,
|
|
||||||
validate,
|
|
||||||
)
|
|
||||||
from ._version_info import VersionInfo
|
|
||||||
|
|
||||||
|
|
||||||
__version__ = "21.4.0"
|
|
||||||
__version_info__ = VersionInfo._from_version_string(__version__)
|
|
||||||
|
|
||||||
__title__ = "attrs"
|
|
||||||
__description__ = "Classes Without Boilerplate"
|
|
||||||
__url__ = "https://www.attrs.org/"
|
|
||||||
__uri__ = __url__
|
|
||||||
__doc__ = __description__ + " <" + __uri__ + ">"
|
|
||||||
|
|
||||||
__author__ = "Hynek Schlawack"
|
|
||||||
__email__ = "hs@ox.cx"
|
|
||||||
|
|
||||||
__license__ = "MIT"
|
|
||||||
__copyright__ = "Copyright (c) 2015 Hynek Schlawack"
|
|
||||||
|
|
||||||
|
|
||||||
s = attributes = attrs
|
|
||||||
ib = attr = attrib
|
|
||||||
dataclass = partial(attrs, auto_attribs=True) # happy Easter ;)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Attribute",
|
|
||||||
"Factory",
|
|
||||||
"NOTHING",
|
|
||||||
"asdict",
|
|
||||||
"assoc",
|
|
||||||
"astuple",
|
|
||||||
"attr",
|
|
||||||
"attrib",
|
|
||||||
"attributes",
|
|
||||||
"attrs",
|
|
||||||
"cmp_using",
|
|
||||||
"converters",
|
|
||||||
"evolve",
|
|
||||||
"exceptions",
|
|
||||||
"fields",
|
|
||||||
"fields_dict",
|
|
||||||
"filters",
|
|
||||||
"get_run_validators",
|
|
||||||
"has",
|
|
||||||
"ib",
|
|
||||||
"make_class",
|
|
||||||
"resolve_types",
|
|
||||||
"s",
|
|
||||||
"set_run_validators",
|
|
||||||
"setters",
|
|
||||||
"validate",
|
|
||||||
"validators",
|
|
||||||
]
|
|
||||||
|
|
||||||
if sys.version_info[:2] >= (3, 6):
|
|
||||||
from ._next_gen import define, field, frozen, mutable # noqa: F401
|
|
||||||
|
|
||||||
__all__.extend(("define", "field", "frozen", "mutable"))
|
|
||||||
@@ -1,154 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
|
||||||
|
|
||||||
import functools
|
|
||||||
|
|
||||||
from ._compat import new_class
|
|
||||||
from ._make import _make_ne
|
|
||||||
|
|
||||||
|
|
||||||
_operation_names = {"eq": "==", "lt": "<", "le": "<=", "gt": ">", "ge": ">="}
|
|
||||||
|
|
||||||
|
|
||||||
def cmp_using(
|
|
||||||
eq=None,
|
|
||||||
lt=None,
|
|
||||||
le=None,
|
|
||||||
gt=None,
|
|
||||||
ge=None,
|
|
||||||
require_same_type=True,
|
|
||||||
class_name="Comparable",
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create a class that can be passed into `attr.ib`'s ``eq``, ``order``, and
|
|
||||||
``cmp`` arguments to customize field comparison.
|
|
||||||
|
|
||||||
The resulting class will have a full set of ordering methods if
|
|
||||||
at least one of ``{lt, le, gt, ge}`` and ``eq`` are provided.
|
|
||||||
|
|
||||||
:param Optional[callable] eq: `callable` used to evaluate equality
|
|
||||||
of two objects.
|
|
||||||
:param Optional[callable] lt: `callable` used to evaluate whether
|
|
||||||
one object is less than another object.
|
|
||||||
:param Optional[callable] le: `callable` used to evaluate whether
|
|
||||||
one object is less than or equal to another object.
|
|
||||||
:param Optional[callable] gt: `callable` used to evaluate whether
|
|
||||||
one object is greater than another object.
|
|
||||||
:param Optional[callable] ge: `callable` used to evaluate whether
|
|
||||||
one object is greater than or equal to another object.
|
|
||||||
|
|
||||||
:param bool require_same_type: When `True`, equality and ordering methods
|
|
||||||
will return `NotImplemented` if objects are not of the same type.
|
|
||||||
|
|
||||||
:param Optional[str] class_name: Name of class. Defaults to 'Comparable'.
|
|
||||||
|
|
||||||
See `comparison` for more details.
|
|
||||||
|
|
||||||
.. versionadded:: 21.1.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
body = {
|
|
||||||
"__slots__": ["value"],
|
|
||||||
"__init__": _make_init(),
|
|
||||||
"_requirements": [],
|
|
||||||
"_is_comparable_to": _is_comparable_to,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add operations.
|
|
||||||
num_order_functions = 0
|
|
||||||
has_eq_function = False
|
|
||||||
|
|
||||||
if eq is not None:
|
|
||||||
has_eq_function = True
|
|
||||||
body["__eq__"] = _make_operator("eq", eq)
|
|
||||||
body["__ne__"] = _make_ne()
|
|
||||||
|
|
||||||
if lt is not None:
|
|
||||||
num_order_functions += 1
|
|
||||||
body["__lt__"] = _make_operator("lt", lt)
|
|
||||||
|
|
||||||
if le is not None:
|
|
||||||
num_order_functions += 1
|
|
||||||
body["__le__"] = _make_operator("le", le)
|
|
||||||
|
|
||||||
if gt is not None:
|
|
||||||
num_order_functions += 1
|
|
||||||
body["__gt__"] = _make_operator("gt", gt)
|
|
||||||
|
|
||||||
if ge is not None:
|
|
||||||
num_order_functions += 1
|
|
||||||
body["__ge__"] = _make_operator("ge", ge)
|
|
||||||
|
|
||||||
type_ = new_class(class_name, (object,), {}, lambda ns: ns.update(body))
|
|
||||||
|
|
||||||
# Add same type requirement.
|
|
||||||
if require_same_type:
|
|
||||||
type_._requirements.append(_check_same_type)
|
|
||||||
|
|
||||||
# Add total ordering if at least one operation was defined.
|
|
||||||
if 0 < num_order_functions < 4:
|
|
||||||
if not has_eq_function:
|
|
||||||
# functools.total_ordering requires __eq__ to be defined,
|
|
||||||
# so raise early error here to keep a nice stack.
|
|
||||||
raise ValueError(
|
|
||||||
"eq must be define is order to complete ordering from "
|
|
||||||
"lt, le, gt, ge."
|
|
||||||
)
|
|
||||||
type_ = functools.total_ordering(type_)
|
|
||||||
|
|
||||||
return type_
|
|
||||||
|
|
||||||
|
|
||||||
def _make_init():
|
|
||||||
"""
|
|
||||||
Create __init__ method.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, value):
|
|
||||||
"""
|
|
||||||
Initialize object with *value*.
|
|
||||||
"""
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
return __init__
|
|
||||||
|
|
||||||
|
|
||||||
def _make_operator(name, func):
|
|
||||||
"""
|
|
||||||
Create operator method.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def method(self, other):
|
|
||||||
if not self._is_comparable_to(other):
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
result = func(self.value, other.value)
|
|
||||||
if result is NotImplemented:
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
method.__name__ = "__%s__" % (name,)
|
|
||||||
method.__doc__ = "Return a %s b. Computed by attrs." % (
|
|
||||||
_operation_names[name],
|
|
||||||
)
|
|
||||||
|
|
||||||
return method
|
|
||||||
|
|
||||||
|
|
||||||
def _is_comparable_to(self, other):
|
|
||||||
"""
|
|
||||||
Check whether `other` is comparable to `self`.
|
|
||||||
"""
|
|
||||||
for func in self._requirements:
|
|
||||||
if not func(self, other):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def _check_same_type(self, other):
|
|
||||||
"""
|
|
||||||
Return True if *self* and *other* are of the same type, False otherwise.
|
|
||||||
"""
|
|
||||||
return other.value.__class__ is self.value.__class__
|
|
||||||
@@ -1,261 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
|
||||||
|
|
||||||
import platform
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
import types
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
|
|
||||||
PY2 = sys.version_info[0] == 2
|
|
||||||
PYPY = platform.python_implementation() == "PyPy"
|
|
||||||
PY36 = sys.version_info[:2] >= (3, 6)
|
|
||||||
HAS_F_STRINGS = PY36
|
|
||||||
PY310 = sys.version_info[:2] >= (3, 10)
|
|
||||||
|
|
||||||
|
|
||||||
if PYPY or PY36:
|
|
||||||
ordered_dict = dict
|
|
||||||
else:
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
ordered_dict = OrderedDict
|
|
||||||
|
|
||||||
|
|
||||||
if PY2:
|
|
||||||
from collections import Mapping, Sequence
|
|
||||||
|
|
||||||
from UserDict import IterableUserDict
|
|
||||||
|
|
||||||
# We 'bundle' isclass instead of using inspect as importing inspect is
|
|
||||||
# fairly expensive (order of 10-15 ms for a modern machine in 2016)
|
|
||||||
def isclass(klass):
|
|
||||||
return isinstance(klass, (type, types.ClassType))
|
|
||||||
|
|
||||||
def new_class(name, bases, kwds, exec_body):
|
|
||||||
"""
|
|
||||||
A minimal stub of types.new_class that we need for make_class.
|
|
||||||
"""
|
|
||||||
ns = {}
|
|
||||||
exec_body(ns)
|
|
||||||
|
|
||||||
return type(name, bases, ns)
|
|
||||||
|
|
||||||
# TYPE is used in exceptions, repr(int) is different on Python 2 and 3.
|
|
||||||
TYPE = "type"
|
|
||||||
|
|
||||||
def iteritems(d):
|
|
||||||
return d.iteritems()
|
|
||||||
|
|
||||||
# Python 2 is bereft of a read-only dict proxy, so we make one!
|
|
||||||
class ReadOnlyDict(IterableUserDict):
|
|
||||||
"""
|
|
||||||
Best-effort read-only dict wrapper.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __setitem__(self, key, val):
|
|
||||||
# We gently pretend we're a Python 3 mappingproxy.
|
|
||||||
raise TypeError(
|
|
||||||
"'mappingproxy' object does not support item assignment"
|
|
||||||
)
|
|
||||||
|
|
||||||
def update(self, _):
|
|
||||||
# We gently pretend we're a Python 3 mappingproxy.
|
|
||||||
raise AttributeError(
|
|
||||||
"'mappingproxy' object has no attribute 'update'"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __delitem__(self, _):
|
|
||||||
# We gently pretend we're a Python 3 mappingproxy.
|
|
||||||
raise TypeError(
|
|
||||||
"'mappingproxy' object does not support item deletion"
|
|
||||||
)
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
# We gently pretend we're a Python 3 mappingproxy.
|
|
||||||
raise AttributeError(
|
|
||||||
"'mappingproxy' object has no attribute 'clear'"
|
|
||||||
)
|
|
||||||
|
|
||||||
def pop(self, key, default=None):
|
|
||||||
# We gently pretend we're a Python 3 mappingproxy.
|
|
||||||
raise AttributeError(
|
|
||||||
"'mappingproxy' object has no attribute 'pop'"
|
|
||||||
)
|
|
||||||
|
|
||||||
def popitem(self):
|
|
||||||
# We gently pretend we're a Python 3 mappingproxy.
|
|
||||||
raise AttributeError(
|
|
||||||
"'mappingproxy' object has no attribute 'popitem'"
|
|
||||||
)
|
|
||||||
|
|
||||||
def setdefault(self, key, default=None):
|
|
||||||
# We gently pretend we're a Python 3 mappingproxy.
|
|
||||||
raise AttributeError(
|
|
||||||
"'mappingproxy' object has no attribute 'setdefault'"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
# Override to be identical to the Python 3 version.
|
|
||||||
return "mappingproxy(" + repr(self.data) + ")"
|
|
||||||
|
|
||||||
def metadata_proxy(d):
|
|
||||||
res = ReadOnlyDict()
|
|
||||||
res.data.update(d) # We blocked update, so we have to do it like this.
|
|
||||||
return res
|
|
||||||
|
|
||||||
def just_warn(*args, **kw): # pragma: no cover
|
|
||||||
"""
|
|
||||||
We only warn on Python 3 because we are not aware of any concrete
|
|
||||||
consequences of not setting the cell on Python 2.
|
|
||||||
"""
|
|
||||||
|
|
||||||
else: # Python 3 and later.
|
|
||||||
from collections.abc import Mapping, Sequence # noqa
|
|
||||||
|
|
||||||
def just_warn(*args, **kw):
|
|
||||||
"""
|
|
||||||
We only warn on Python 3 because we are not aware of any concrete
|
|
||||||
consequences of not setting the cell on Python 2.
|
|
||||||
"""
|
|
||||||
warnings.warn(
|
|
||||||
"Running interpreter doesn't sufficiently support code object "
|
|
||||||
"introspection. Some features like bare super() or accessing "
|
|
||||||
"__class__ will not work with slotted classes.",
|
|
||||||
RuntimeWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
def isclass(klass):
|
|
||||||
return isinstance(klass, type)
|
|
||||||
|
|
||||||
TYPE = "class"
|
|
||||||
|
|
||||||
def iteritems(d):
|
|
||||||
return d.items()
|
|
||||||
|
|
||||||
new_class = types.new_class
|
|
||||||
|
|
||||||
def metadata_proxy(d):
|
|
||||||
return types.MappingProxyType(dict(d))
|
|
||||||
|
|
||||||
|
|
||||||
def make_set_closure_cell():
|
|
||||||
"""Return a function of two arguments (cell, value) which sets
|
|
||||||
the value stored in the closure cell `cell` to `value`.
|
|
||||||
"""
|
|
||||||
# pypy makes this easy. (It also supports the logic below, but
|
|
||||||
# why not do the easy/fast thing?)
|
|
||||||
if PYPY:
|
|
||||||
|
|
||||||
def set_closure_cell(cell, value):
|
|
||||||
cell.__setstate__((value,))
|
|
||||||
|
|
||||||
return set_closure_cell
|
|
||||||
|
|
||||||
# Otherwise gotta do it the hard way.
|
|
||||||
|
|
||||||
# Create a function that will set its first cellvar to `value`.
|
|
||||||
def set_first_cellvar_to(value):
|
|
||||||
x = value
|
|
||||||
return
|
|
||||||
|
|
||||||
# This function will be eliminated as dead code, but
|
|
||||||
# not before its reference to `x` forces `x` to be
|
|
||||||
# represented as a closure cell rather than a local.
|
|
||||||
def force_x_to_be_a_cell(): # pragma: no cover
|
|
||||||
return x
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Extract the code object and make sure our assumptions about
|
|
||||||
# the closure behavior are correct.
|
|
||||||
if PY2:
|
|
||||||
co = set_first_cellvar_to.func_code
|
|
||||||
else:
|
|
||||||
co = set_first_cellvar_to.__code__
|
|
||||||
if co.co_cellvars != ("x",) or co.co_freevars != ():
|
|
||||||
raise AssertionError # pragma: no cover
|
|
||||||
|
|
||||||
# Convert this code object to a code object that sets the
|
|
||||||
# function's first _freevar_ (not cellvar) to the argument.
|
|
||||||
if sys.version_info >= (3, 8):
|
|
||||||
# CPython 3.8+ has an incompatible CodeType signature
|
|
||||||
# (added a posonlyargcount argument) but also added
|
|
||||||
# CodeType.replace() to do this without counting parameters.
|
|
||||||
set_first_freevar_code = co.replace(
|
|
||||||
co_cellvars=co.co_freevars, co_freevars=co.co_cellvars
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
args = [co.co_argcount]
|
|
||||||
if not PY2:
|
|
||||||
args.append(co.co_kwonlyargcount)
|
|
||||||
args.extend(
|
|
||||||
[
|
|
||||||
co.co_nlocals,
|
|
||||||
co.co_stacksize,
|
|
||||||
co.co_flags,
|
|
||||||
co.co_code,
|
|
||||||
co.co_consts,
|
|
||||||
co.co_names,
|
|
||||||
co.co_varnames,
|
|
||||||
co.co_filename,
|
|
||||||
co.co_name,
|
|
||||||
co.co_firstlineno,
|
|
||||||
co.co_lnotab,
|
|
||||||
# These two arguments are reversed:
|
|
||||||
co.co_cellvars,
|
|
||||||
co.co_freevars,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
set_first_freevar_code = types.CodeType(*args)
|
|
||||||
|
|
||||||
def set_closure_cell(cell, value):
|
|
||||||
# Create a function using the set_first_freevar_code,
|
|
||||||
# whose first closure cell is `cell`. Calling it will
|
|
||||||
# change the value of that cell.
|
|
||||||
setter = types.FunctionType(
|
|
||||||
set_first_freevar_code, {}, "setter", (), (cell,)
|
|
||||||
)
|
|
||||||
# And call it to set the cell.
|
|
||||||
setter(value)
|
|
||||||
|
|
||||||
# Make sure it works on this interpreter:
|
|
||||||
def make_func_with_cell():
|
|
||||||
x = None
|
|
||||||
|
|
||||||
def func():
|
|
||||||
return x # pragma: no cover
|
|
||||||
|
|
||||||
return func
|
|
||||||
|
|
||||||
if PY2:
|
|
||||||
cell = make_func_with_cell().func_closure[0]
|
|
||||||
else:
|
|
||||||
cell = make_func_with_cell().__closure__[0]
|
|
||||||
set_closure_cell(cell, 100)
|
|
||||||
if cell.cell_contents != 100:
|
|
||||||
raise AssertionError # pragma: no cover
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
return just_warn
|
|
||||||
else:
|
|
||||||
return set_closure_cell
|
|
||||||
|
|
||||||
|
|
||||||
set_closure_cell = make_set_closure_cell()
|
|
||||||
|
|
||||||
# Thread-local global to track attrs instances which are already being repr'd.
|
|
||||||
# This is needed because there is no other (thread-safe) way to pass info
|
|
||||||
# about the instances that are already being repr'd through the call stack
|
|
||||||
# in order to ensure we don't perform infinite recursion.
|
|
||||||
#
|
|
||||||
# For instance, if an instance contains a dict which contains that instance,
|
|
||||||
# we need to know that we're already repr'ing the outside instance from within
|
|
||||||
# the dict's repr() call.
|
|
||||||
#
|
|
||||||
# This lives here rather than in _make.py so that the functions in _make.py
|
|
||||||
# don't have a direct reference to the thread-local in their globals dict.
|
|
||||||
# If they have such a reference, it breaks cloudpickle.
|
|
||||||
repr_context = threading.local()
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["set_run_validators", "get_run_validators"]
|
|
||||||
|
|
||||||
_run_validators = True
|
|
||||||
|
|
||||||
|
|
||||||
def set_run_validators(run):
|
|
||||||
"""
|
|
||||||
Set whether or not validators are run. By default, they are run.
|
|
||||||
|
|
||||||
.. deprecated:: 21.3.0 It will not be removed, but it also will not be
|
|
||||||
moved to new ``attrs`` namespace. Use `attrs.validators.set_disabled()`
|
|
||||||
instead.
|
|
||||||
"""
|
|
||||||
if not isinstance(run, bool):
|
|
||||||
raise TypeError("'run' must be bool.")
|
|
||||||
global _run_validators
|
|
||||||
_run_validators = run
|
|
||||||
|
|
||||||
|
|
||||||
def get_run_validators():
|
|
||||||
"""
|
|
||||||
Return whether or not validators are run.
|
|
||||||
|
|
||||||
.. deprecated:: 21.3.0 It will not be removed, but it also will not be
|
|
||||||
moved to new ``attrs`` namespace. Use `attrs.validators.get_disabled()`
|
|
||||||
instead.
|
|
||||||
"""
|
|
||||||
return _run_validators
|
|
||||||
@@ -1,422 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
|
||||||
|
|
||||||
import copy
|
|
||||||
|
|
||||||
from ._compat import iteritems
|
|
||||||
from ._make import NOTHING, _obj_setattr, fields
|
|
||||||
from .exceptions import AttrsAttributeNotFoundError
|
|
||||||
|
|
||||||
|
|
||||||
def asdict(
|
|
||||||
inst,
|
|
||||||
recurse=True,
|
|
||||||
filter=None,
|
|
||||||
dict_factory=dict,
|
|
||||||
retain_collection_types=False,
|
|
||||||
value_serializer=None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Return the ``attrs`` attribute values of *inst* as a dict.
|
|
||||||
|
|
||||||
Optionally recurse into other ``attrs``-decorated classes.
|
|
||||||
|
|
||||||
:param inst: Instance of an ``attrs``-decorated class.
|
|
||||||
:param bool recurse: Recurse into classes that are also
|
|
||||||
``attrs``-decorated.
|
|
||||||
:param callable filter: A callable whose return code determines whether an
|
|
||||||
attribute or element is included (``True``) or dropped (``False``). Is
|
|
||||||
called with the `attrs.Attribute` as the first argument and the
|
|
||||||
value as the second argument.
|
|
||||||
:param callable dict_factory: A callable to produce dictionaries from. For
|
|
||||||
example, to produce ordered dictionaries instead of normal Python
|
|
||||||
dictionaries, pass in ``collections.OrderedDict``.
|
|
||||||
:param bool retain_collection_types: Do not convert to ``list`` when
|
|
||||||
encountering an attribute whose type is ``tuple`` or ``set``. Only
|
|
||||||
meaningful if ``recurse`` is ``True``.
|
|
||||||
:param Optional[callable] value_serializer: A hook that is called for every
|
|
||||||
attribute or dict key/value. It receives the current instance, field
|
|
||||||
and value and must return the (updated) value. The hook is run *after*
|
|
||||||
the optional *filter* has been applied.
|
|
||||||
|
|
||||||
:rtype: return type of *dict_factory*
|
|
||||||
|
|
||||||
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
|
|
||||||
class.
|
|
||||||
|
|
||||||
.. versionadded:: 16.0.0 *dict_factory*
|
|
||||||
.. versionadded:: 16.1.0 *retain_collection_types*
|
|
||||||
.. versionadded:: 20.3.0 *value_serializer*
|
|
||||||
.. versionadded:: 21.3.0 If a dict has a collection for a key, it is
|
|
||||||
serialized as a tuple.
|
|
||||||
"""
|
|
||||||
attrs = fields(inst.__class__)
|
|
||||||
rv = dict_factory()
|
|
||||||
for a in attrs:
|
|
||||||
v = getattr(inst, a.name)
|
|
||||||
if filter is not None and not filter(a, v):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if value_serializer is not None:
|
|
||||||
v = value_serializer(inst, a, v)
|
|
||||||
|
|
||||||
if recurse is True:
|
|
||||||
if has(v.__class__):
|
|
||||||
rv[a.name] = asdict(
|
|
||||||
v,
|
|
||||||
recurse=True,
|
|
||||||
filter=filter,
|
|
||||||
dict_factory=dict_factory,
|
|
||||||
retain_collection_types=retain_collection_types,
|
|
||||||
value_serializer=value_serializer,
|
|
||||||
)
|
|
||||||
elif isinstance(v, (tuple, list, set, frozenset)):
|
|
||||||
cf = v.__class__ if retain_collection_types is True else list
|
|
||||||
rv[a.name] = cf(
|
|
||||||
[
|
|
||||||
_asdict_anything(
|
|
||||||
i,
|
|
||||||
is_key=False,
|
|
||||||
filter=filter,
|
|
||||||
dict_factory=dict_factory,
|
|
||||||
retain_collection_types=retain_collection_types,
|
|
||||||
value_serializer=value_serializer,
|
|
||||||
)
|
|
||||||
for i in v
|
|
||||||
]
|
|
||||||
)
|
|
||||||
elif isinstance(v, dict):
|
|
||||||
df = dict_factory
|
|
||||||
rv[a.name] = df(
|
|
||||||
(
|
|
||||||
_asdict_anything(
|
|
||||||
kk,
|
|
||||||
is_key=True,
|
|
||||||
filter=filter,
|
|
||||||
dict_factory=df,
|
|
||||||
retain_collection_types=retain_collection_types,
|
|
||||||
value_serializer=value_serializer,
|
|
||||||
),
|
|
||||||
_asdict_anything(
|
|
||||||
vv,
|
|
||||||
is_key=False,
|
|
||||||
filter=filter,
|
|
||||||
dict_factory=df,
|
|
||||||
retain_collection_types=retain_collection_types,
|
|
||||||
value_serializer=value_serializer,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
for kk, vv in iteritems(v)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
rv[a.name] = v
|
|
||||||
else:
|
|
||||||
rv[a.name] = v
|
|
||||||
return rv
|
|
||||||
|
|
||||||
|
|
||||||
def _asdict_anything(
|
|
||||||
val,
|
|
||||||
is_key,
|
|
||||||
filter,
|
|
||||||
dict_factory,
|
|
||||||
retain_collection_types,
|
|
||||||
value_serializer,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
``asdict`` only works on attrs instances, this works on anything.
|
|
||||||
"""
|
|
||||||
if getattr(val.__class__, "__attrs_attrs__", None) is not None:
|
|
||||||
# Attrs class.
|
|
||||||
rv = asdict(
|
|
||||||
val,
|
|
||||||
recurse=True,
|
|
||||||
filter=filter,
|
|
||||||
dict_factory=dict_factory,
|
|
||||||
retain_collection_types=retain_collection_types,
|
|
||||||
value_serializer=value_serializer,
|
|
||||||
)
|
|
||||||
elif isinstance(val, (tuple, list, set, frozenset)):
|
|
||||||
if retain_collection_types is True:
|
|
||||||
cf = val.__class__
|
|
||||||
elif is_key:
|
|
||||||
cf = tuple
|
|
||||||
else:
|
|
||||||
cf = list
|
|
||||||
|
|
||||||
rv = cf(
|
|
||||||
[
|
|
||||||
_asdict_anything(
|
|
||||||
i,
|
|
||||||
is_key=False,
|
|
||||||
filter=filter,
|
|
||||||
dict_factory=dict_factory,
|
|
||||||
retain_collection_types=retain_collection_types,
|
|
||||||
value_serializer=value_serializer,
|
|
||||||
)
|
|
||||||
for i in val
|
|
||||||
]
|
|
||||||
)
|
|
||||||
elif isinstance(val, dict):
|
|
||||||
df = dict_factory
|
|
||||||
rv = df(
|
|
||||||
(
|
|
||||||
_asdict_anything(
|
|
||||||
kk,
|
|
||||||
is_key=True,
|
|
||||||
filter=filter,
|
|
||||||
dict_factory=df,
|
|
||||||
retain_collection_types=retain_collection_types,
|
|
||||||
value_serializer=value_serializer,
|
|
||||||
),
|
|
||||||
_asdict_anything(
|
|
||||||
vv,
|
|
||||||
is_key=False,
|
|
||||||
filter=filter,
|
|
||||||
dict_factory=df,
|
|
||||||
retain_collection_types=retain_collection_types,
|
|
||||||
value_serializer=value_serializer,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
for kk, vv in iteritems(val)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
rv = val
|
|
||||||
if value_serializer is not None:
|
|
||||||
rv = value_serializer(None, None, rv)
|
|
||||||
|
|
||||||
return rv
|
|
||||||
|
|
||||||
|
|
||||||
def astuple(
|
|
||||||
inst,
|
|
||||||
recurse=True,
|
|
||||||
filter=None,
|
|
||||||
tuple_factory=tuple,
|
|
||||||
retain_collection_types=False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Return the ``attrs`` attribute values of *inst* as a tuple.
|
|
||||||
|
|
||||||
Optionally recurse into other ``attrs``-decorated classes.
|
|
||||||
|
|
||||||
:param inst: Instance of an ``attrs``-decorated class.
|
|
||||||
:param bool recurse: Recurse into classes that are also
|
|
||||||
``attrs``-decorated.
|
|
||||||
:param callable filter: A callable whose return code determines whether an
|
|
||||||
attribute or element is included (``True``) or dropped (``False``). Is
|
|
||||||
called with the `attrs.Attribute` as the first argument and the
|
|
||||||
value as the second argument.
|
|
||||||
:param callable tuple_factory: A callable to produce tuples from. For
|
|
||||||
example, to produce lists instead of tuples.
|
|
||||||
:param bool retain_collection_types: Do not convert to ``list``
|
|
||||||
or ``dict`` when encountering an attribute which type is
|
|
||||||
``tuple``, ``dict`` or ``set``. Only meaningful if ``recurse`` is
|
|
||||||
``True``.
|
|
||||||
|
|
||||||
:rtype: return type of *tuple_factory*
|
|
||||||
|
|
||||||
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
|
|
||||||
class.
|
|
||||||
|
|
||||||
.. versionadded:: 16.2.0
|
|
||||||
"""
|
|
||||||
attrs = fields(inst.__class__)
|
|
||||||
rv = []
|
|
||||||
retain = retain_collection_types # Very long. :/
|
|
||||||
for a in attrs:
|
|
||||||
v = getattr(inst, a.name)
|
|
||||||
if filter is not None and not filter(a, v):
|
|
||||||
continue
|
|
||||||
if recurse is True:
|
|
||||||
if has(v.__class__):
|
|
||||||
rv.append(
|
|
||||||
astuple(
|
|
||||||
v,
|
|
||||||
recurse=True,
|
|
||||||
filter=filter,
|
|
||||||
tuple_factory=tuple_factory,
|
|
||||||
retain_collection_types=retain,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif isinstance(v, (tuple, list, set, frozenset)):
|
|
||||||
cf = v.__class__ if retain is True else list
|
|
||||||
rv.append(
|
|
||||||
cf(
|
|
||||||
[
|
|
||||||
astuple(
|
|
||||||
j,
|
|
||||||
recurse=True,
|
|
||||||
filter=filter,
|
|
||||||
tuple_factory=tuple_factory,
|
|
||||||
retain_collection_types=retain,
|
|
||||||
)
|
|
||||||
if has(j.__class__)
|
|
||||||
else j
|
|
||||||
for j in v
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif isinstance(v, dict):
|
|
||||||
df = v.__class__ if retain is True else dict
|
|
||||||
rv.append(
|
|
||||||
df(
|
|
||||||
(
|
|
||||||
astuple(
|
|
||||||
kk,
|
|
||||||
tuple_factory=tuple_factory,
|
|
||||||
retain_collection_types=retain,
|
|
||||||
)
|
|
||||||
if has(kk.__class__)
|
|
||||||
else kk,
|
|
||||||
astuple(
|
|
||||||
vv,
|
|
||||||
tuple_factory=tuple_factory,
|
|
||||||
retain_collection_types=retain,
|
|
||||||
)
|
|
||||||
if has(vv.__class__)
|
|
||||||
else vv,
|
|
||||||
)
|
|
||||||
for kk, vv in iteritems(v)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
rv.append(v)
|
|
||||||
else:
|
|
||||||
rv.append(v)
|
|
||||||
|
|
||||||
return rv if tuple_factory is list else tuple_factory(rv)
|
|
||||||
|
|
||||||
|
|
||||||
def has(cls):
|
|
||||||
"""
|
|
||||||
Check whether *cls* is a class with ``attrs`` attributes.
|
|
||||||
|
|
||||||
:param type cls: Class to introspect.
|
|
||||||
:raise TypeError: If *cls* is not a class.
|
|
||||||
|
|
||||||
:rtype: bool
|
|
||||||
"""
|
|
||||||
return getattr(cls, "__attrs_attrs__", None) is not None
|
|
||||||
|
|
||||||
|
|
||||||
def assoc(inst, **changes):
|
|
||||||
"""
|
|
||||||
Copy *inst* and apply *changes*.
|
|
||||||
|
|
||||||
:param inst: Instance of a class with ``attrs`` attributes.
|
|
||||||
:param changes: Keyword changes in the new copy.
|
|
||||||
|
|
||||||
:return: A copy of inst with *changes* incorporated.
|
|
||||||
|
|
||||||
:raise attr.exceptions.AttrsAttributeNotFoundError: If *attr_name* couldn't
|
|
||||||
be found on *cls*.
|
|
||||||
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
|
|
||||||
class.
|
|
||||||
|
|
||||||
.. deprecated:: 17.1.0
|
|
||||||
Use `attrs.evolve` instead if you can.
|
|
||||||
This function will not be removed du to the slightly different approach
|
|
||||||
compared to `attrs.evolve`.
|
|
||||||
"""
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
warnings.warn(
|
|
||||||
"assoc is deprecated and will be removed after 2018/01.",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
new = copy.copy(inst)
|
|
||||||
attrs = fields(inst.__class__)
|
|
||||||
for k, v in iteritems(changes):
|
|
||||||
a = getattr(attrs, k, NOTHING)
|
|
||||||
if a is NOTHING:
|
|
||||||
raise AttrsAttributeNotFoundError(
|
|
||||||
"{k} is not an attrs attribute on {cl}.".format(
|
|
||||||
k=k, cl=new.__class__
|
|
||||||
)
|
|
||||||
)
|
|
||||||
_obj_setattr(new, k, v)
|
|
||||||
return new
|
|
||||||
|
|
||||||
|
|
||||||
def evolve(inst, **changes):
|
|
||||||
"""
|
|
||||||
Create a new instance, based on *inst* with *changes* applied.
|
|
||||||
|
|
||||||
:param inst: Instance of a class with ``attrs`` attributes.
|
|
||||||
:param changes: Keyword changes in the new copy.
|
|
||||||
|
|
||||||
:return: A copy of inst with *changes* incorporated.
|
|
||||||
|
|
||||||
:raise TypeError: If *attr_name* couldn't be found in the class
|
|
||||||
``__init__``.
|
|
||||||
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
|
|
||||||
class.
|
|
||||||
|
|
||||||
.. versionadded:: 17.1.0
|
|
||||||
"""
|
|
||||||
cls = inst.__class__
|
|
||||||
attrs = fields(cls)
|
|
||||||
for a in attrs:
|
|
||||||
if not a.init:
|
|
||||||
continue
|
|
||||||
attr_name = a.name # To deal with private attributes.
|
|
||||||
init_name = attr_name if attr_name[0] != "_" else attr_name[1:]
|
|
||||||
if init_name not in changes:
|
|
||||||
changes[init_name] = getattr(inst, attr_name)
|
|
||||||
|
|
||||||
return cls(**changes)
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_types(cls, globalns=None, localns=None, attribs=None):
|
|
||||||
"""
|
|
||||||
Resolve any strings and forward annotations in type annotations.
|
|
||||||
|
|
||||||
This is only required if you need concrete types in `Attribute`'s *type*
|
|
||||||
field. In other words, you don't need to resolve your types if you only
|
|
||||||
use them for static type checking.
|
|
||||||
|
|
||||||
With no arguments, names will be looked up in the module in which the class
|
|
||||||
was created. If this is not what you want, e.g. if the name only exists
|
|
||||||
inside a method, you may pass *globalns* or *localns* to specify other
|
|
||||||
dictionaries in which to look up these names. See the docs of
|
|
||||||
`typing.get_type_hints` for more details.
|
|
||||||
|
|
||||||
:param type cls: Class to resolve.
|
|
||||||
:param Optional[dict] globalns: Dictionary containing global variables.
|
|
||||||
:param Optional[dict] localns: Dictionary containing local variables.
|
|
||||||
:param Optional[list] attribs: List of attribs for the given class.
|
|
||||||
This is necessary when calling from inside a ``field_transformer``
|
|
||||||
since *cls* is not an ``attrs`` class yet.
|
|
||||||
|
|
||||||
:raise TypeError: If *cls* is not a class.
|
|
||||||
:raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs``
|
|
||||||
class and you didn't pass any attribs.
|
|
||||||
:raise NameError: If types cannot be resolved because of missing variables.
|
|
||||||
|
|
||||||
:returns: *cls* so you can use this function also as a class decorator.
|
|
||||||
Please note that you have to apply it **after** `attrs.define`. That
|
|
||||||
means the decorator has to come in the line **before** `attrs.define`.
|
|
||||||
|
|
||||||
.. versionadded:: 20.1.0
|
|
||||||
.. versionadded:: 21.1.0 *attribs*
|
|
||||||
|
|
||||||
"""
|
|
||||||
# Since calling get_type_hints is expensive we cache whether we've
|
|
||||||
# done it already.
|
|
||||||
if getattr(cls, "__attrs_types_resolved__", None) != cls:
|
|
||||||
import typing
|
|
||||||
|
|
||||||
hints = typing.get_type_hints(cls, globalns=globalns, localns=localns)
|
|
||||||
for field in fields(cls) if attribs is None else attribs:
|
|
||||||
if field.name in hints:
|
|
||||||
# Since fields have been frozen we must work around it.
|
|
||||||
_obj_setattr(field, "type", hints[field.name])
|
|
||||||
# We store the class we resolved so that subclasses know they haven't
|
|
||||||
# been resolved.
|
|
||||||
cls.__attrs_types_resolved__ = cls
|
|
||||||
|
|
||||||
# Return the class so you can use it as a decorator too.
|
|
||||||
return cls
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,216 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
"""
|
|
||||||
These are Python 3.6+-only and keyword-only APIs that call `attr.s` and
|
|
||||||
`attr.ib` with different default values.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from . import setters
|
|
||||||
from ._funcs import asdict as _asdict
|
|
||||||
from ._funcs import astuple as _astuple
|
|
||||||
from ._make import (
|
|
||||||
NOTHING,
|
|
||||||
_frozen_setattrs,
|
|
||||||
_ng_default_on_setattr,
|
|
||||||
attrib,
|
|
||||||
attrs,
|
|
||||||
)
|
|
||||||
from .exceptions import UnannotatedAttributeError
|
|
||||||
|
|
||||||
|
|
||||||
def define(
|
|
||||||
maybe_cls=None,
|
|
||||||
*,
|
|
||||||
these=None,
|
|
||||||
repr=None,
|
|
||||||
hash=None,
|
|
||||||
init=None,
|
|
||||||
slots=True,
|
|
||||||
frozen=False,
|
|
||||||
weakref_slot=True,
|
|
||||||
str=False,
|
|
||||||
auto_attribs=None,
|
|
||||||
kw_only=False,
|
|
||||||
cache_hash=False,
|
|
||||||
auto_exc=True,
|
|
||||||
eq=None,
|
|
||||||
order=False,
|
|
||||||
auto_detect=True,
|
|
||||||
getstate_setstate=None,
|
|
||||||
on_setattr=None,
|
|
||||||
field_transformer=None,
|
|
||||||
match_args=True,
|
|
||||||
):
|
|
||||||
r"""
|
|
||||||
Define an ``attrs`` class.
|
|
||||||
|
|
||||||
Differences to the classic `attr.s` that it uses underneath:
|
|
||||||
|
|
||||||
- Automatically detect whether or not *auto_attribs* should be `True`
|
|
||||||
(c.f. *auto_attribs* parameter).
|
|
||||||
- If *frozen* is `False`, run converters and validators when setting an
|
|
||||||
attribute by default.
|
|
||||||
- *slots=True* (see :term:`slotted classes` for potentially surprising
|
|
||||||
behaviors)
|
|
||||||
- *auto_exc=True*
|
|
||||||
- *auto_detect=True*
|
|
||||||
- *order=False*
|
|
||||||
- *match_args=True*
|
|
||||||
- Some options that were only relevant on Python 2 or were kept around for
|
|
||||||
backwards-compatibility have been removed.
|
|
||||||
|
|
||||||
Please note that these are all defaults and you can change them as you
|
|
||||||
wish.
|
|
||||||
|
|
||||||
:param Optional[bool] auto_attribs: If set to `True` or `False`, it behaves
|
|
||||||
exactly like `attr.s`. If left `None`, `attr.s` will try to guess:
|
|
||||||
|
|
||||||
1. If any attributes are annotated and no unannotated `attrs.fields`\ s
|
|
||||||
are found, it assumes *auto_attribs=True*.
|
|
||||||
2. Otherwise it assumes *auto_attribs=False* and tries to collect
|
|
||||||
`attrs.fields`\ s.
|
|
||||||
|
|
||||||
For now, please refer to `attr.s` for the rest of the parameters.
|
|
||||||
|
|
||||||
.. versionadded:: 20.1.0
|
|
||||||
.. versionchanged:: 21.3.0 Converters are also run ``on_setattr``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def do_it(cls, auto_attribs):
|
|
||||||
return attrs(
|
|
||||||
maybe_cls=cls,
|
|
||||||
these=these,
|
|
||||||
repr=repr,
|
|
||||||
hash=hash,
|
|
||||||
init=init,
|
|
||||||
slots=slots,
|
|
||||||
frozen=frozen,
|
|
||||||
weakref_slot=weakref_slot,
|
|
||||||
str=str,
|
|
||||||
auto_attribs=auto_attribs,
|
|
||||||
kw_only=kw_only,
|
|
||||||
cache_hash=cache_hash,
|
|
||||||
auto_exc=auto_exc,
|
|
||||||
eq=eq,
|
|
||||||
order=order,
|
|
||||||
auto_detect=auto_detect,
|
|
||||||
collect_by_mro=True,
|
|
||||||
getstate_setstate=getstate_setstate,
|
|
||||||
on_setattr=on_setattr,
|
|
||||||
field_transformer=field_transformer,
|
|
||||||
match_args=match_args,
|
|
||||||
)
|
|
||||||
|
|
||||||
def wrap(cls):
|
|
||||||
"""
|
|
||||||
Making this a wrapper ensures this code runs during class creation.
|
|
||||||
|
|
||||||
We also ensure that frozen-ness of classes is inherited.
|
|
||||||
"""
|
|
||||||
nonlocal frozen, on_setattr
|
|
||||||
|
|
||||||
had_on_setattr = on_setattr not in (None, setters.NO_OP)
|
|
||||||
|
|
||||||
# By default, mutable classes convert & validate on setattr.
|
|
||||||
if frozen is False and on_setattr is None:
|
|
||||||
on_setattr = _ng_default_on_setattr
|
|
||||||
|
|
||||||
# However, if we subclass a frozen class, we inherit the immutability
|
|
||||||
# and disable on_setattr.
|
|
||||||
for base_cls in cls.__bases__:
|
|
||||||
if base_cls.__setattr__ is _frozen_setattrs:
|
|
||||||
if had_on_setattr:
|
|
||||||
raise ValueError(
|
|
||||||
"Frozen classes can't use on_setattr "
|
|
||||||
"(frozen-ness was inherited)."
|
|
||||||
)
|
|
||||||
|
|
||||||
on_setattr = setters.NO_OP
|
|
||||||
break
|
|
||||||
|
|
||||||
if auto_attribs is not None:
|
|
||||||
return do_it(cls, auto_attribs)
|
|
||||||
|
|
||||||
try:
|
|
||||||
return do_it(cls, True)
|
|
||||||
except UnannotatedAttributeError:
|
|
||||||
return do_it(cls, False)
|
|
||||||
|
|
||||||
# maybe_cls's type depends on the usage of the decorator. It's a class
|
|
||||||
# if it's used as `@attrs` but ``None`` if used as `@attrs()`.
|
|
||||||
if maybe_cls is None:
|
|
||||||
return wrap
|
|
||||||
else:
|
|
||||||
return wrap(maybe_cls)
|
|
||||||
|
|
||||||
|
|
||||||
mutable = define
|
|
||||||
frozen = partial(define, frozen=True, on_setattr=None)
|
|
||||||
|
|
||||||
|
|
||||||
def field(
|
|
||||||
*,
|
|
||||||
default=NOTHING,
|
|
||||||
validator=None,
|
|
||||||
repr=True,
|
|
||||||
hash=None,
|
|
||||||
init=True,
|
|
||||||
metadata=None,
|
|
||||||
converter=None,
|
|
||||||
factory=None,
|
|
||||||
kw_only=False,
|
|
||||||
eq=None,
|
|
||||||
order=None,
|
|
||||||
on_setattr=None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Identical to `attr.ib`, except keyword-only and with some arguments
|
|
||||||
removed.
|
|
||||||
|
|
||||||
.. versionadded:: 20.1.0
|
|
||||||
"""
|
|
||||||
return attrib(
|
|
||||||
default=default,
|
|
||||||
validator=validator,
|
|
||||||
repr=repr,
|
|
||||||
hash=hash,
|
|
||||||
init=init,
|
|
||||||
metadata=metadata,
|
|
||||||
converter=converter,
|
|
||||||
factory=factory,
|
|
||||||
kw_only=kw_only,
|
|
||||||
eq=eq,
|
|
||||||
order=order,
|
|
||||||
on_setattr=on_setattr,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def asdict(inst, *, recurse=True, filter=None, value_serializer=None):
|
|
||||||
"""
|
|
||||||
Same as `attr.asdict`, except that collections types are always retained
|
|
||||||
and dict is always used as *dict_factory*.
|
|
||||||
|
|
||||||
.. versionadded:: 21.3.0
|
|
||||||
"""
|
|
||||||
return _asdict(
|
|
||||||
inst=inst,
|
|
||||||
recurse=recurse,
|
|
||||||
filter=filter,
|
|
||||||
value_serializer=value_serializer,
|
|
||||||
retain_collection_types=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def astuple(inst, *, recurse=True, filter=None):
|
|
||||||
"""
|
|
||||||
Same as `attr.astuple`, except that collections types are always retained
|
|
||||||
and `tuple` is always used as the *tuple_factory*.
|
|
||||||
|
|
||||||
.. versionadded:: 21.3.0
|
|
||||||
"""
|
|
||||||
return _astuple(
|
|
||||||
inst=inst, recurse=recurse, filter=filter, retain_collection_types=True
|
|
||||||
)
|
|
||||||
@@ -1,87 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
|
||||||
|
|
||||||
from functools import total_ordering
|
|
||||||
|
|
||||||
from ._funcs import astuple
|
|
||||||
from ._make import attrib, attrs
|
|
||||||
|
|
||||||
|
|
||||||
@total_ordering
|
|
||||||
@attrs(eq=False, order=False, slots=True, frozen=True)
|
|
||||||
class VersionInfo(object):
|
|
||||||
"""
|
|
||||||
A version object that can be compared to tuple of length 1--4:
|
|
||||||
|
|
||||||
>>> attr.VersionInfo(19, 1, 0, "final") <= (19, 2)
|
|
||||||
True
|
|
||||||
>>> attr.VersionInfo(19, 1, 0, "final") < (19, 1, 1)
|
|
||||||
True
|
|
||||||
>>> vi = attr.VersionInfo(19, 2, 0, "final")
|
|
||||||
>>> vi < (19, 1, 1)
|
|
||||||
False
|
|
||||||
>>> vi < (19,)
|
|
||||||
False
|
|
||||||
>>> vi == (19, 2,)
|
|
||||||
True
|
|
||||||
>>> vi == (19, 2, 1)
|
|
||||||
False
|
|
||||||
|
|
||||||
.. versionadded:: 19.2
|
|
||||||
"""
|
|
||||||
|
|
||||||
year = attrib(type=int)
|
|
||||||
minor = attrib(type=int)
|
|
||||||
micro = attrib(type=int)
|
|
||||||
releaselevel = attrib(type=str)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _from_version_string(cls, s):
|
|
||||||
"""
|
|
||||||
Parse *s* and return a _VersionInfo.
|
|
||||||
"""
|
|
||||||
v = s.split(".")
|
|
||||||
if len(v) == 3:
|
|
||||||
v.append("final")
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
year=int(v[0]), minor=int(v[1]), micro=int(v[2]), releaselevel=v[3]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _ensure_tuple(self, other):
|
|
||||||
"""
|
|
||||||
Ensure *other* is a tuple of a valid length.
|
|
||||||
|
|
||||||
Returns a possibly transformed *other* and ourselves as a tuple of
|
|
||||||
the same length as *other*.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.__class__ is other.__class__:
|
|
||||||
other = astuple(other)
|
|
||||||
|
|
||||||
if not isinstance(other, tuple):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
if not (1 <= len(other) <= 4):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
return astuple(self)[: len(other)], other
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
try:
|
|
||||||
us, them = self._ensure_tuple(other)
|
|
||||||
except NotImplementedError:
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
return us == them
|
|
||||||
|
|
||||||
def __lt__(self, other):
|
|
||||||
try:
|
|
||||||
us, them = self._ensure_tuple(other)
|
|
||||||
except NotImplementedError:
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
# Since alphabetically "dev0" < "final" < "post1" < "post2", we don't
|
|
||||||
# have to do anything special with releaselevel for now.
|
|
||||||
return us < them
|
|
||||||
@@ -1,155 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
"""
|
|
||||||
Commonly useful converters.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
|
||||||
|
|
||||||
from ._compat import PY2
|
|
||||||
from ._make import NOTHING, Factory, pipe
|
|
||||||
|
|
||||||
|
|
||||||
if not PY2:
|
|
||||||
import inspect
|
|
||||||
import typing
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"default_if_none",
|
|
||||||
"optional",
|
|
||||||
"pipe",
|
|
||||||
"to_bool",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def optional(converter):
|
|
||||||
"""
|
|
||||||
A converter that allows an attribute to be optional. An optional attribute
|
|
||||||
is one which can be set to ``None``.
|
|
||||||
|
|
||||||
Type annotations will be inferred from the wrapped converter's, if it
|
|
||||||
has any.
|
|
||||||
|
|
||||||
:param callable converter: the converter that is used for non-``None``
|
|
||||||
values.
|
|
||||||
|
|
||||||
.. versionadded:: 17.1.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
def optional_converter(val):
|
|
||||||
if val is None:
|
|
||||||
return None
|
|
||||||
return converter(val)
|
|
||||||
|
|
||||||
if not PY2:
|
|
||||||
sig = None
|
|
||||||
try:
|
|
||||||
sig = inspect.signature(converter)
|
|
||||||
except (ValueError, TypeError): # inspect failed
|
|
||||||
pass
|
|
||||||
if sig:
|
|
||||||
params = list(sig.parameters.values())
|
|
||||||
if params and params[0].annotation is not inspect.Parameter.empty:
|
|
||||||
optional_converter.__annotations__["val"] = typing.Optional[
|
|
||||||
params[0].annotation
|
|
||||||
]
|
|
||||||
if sig.return_annotation is not inspect.Signature.empty:
|
|
||||||
optional_converter.__annotations__["return"] = typing.Optional[
|
|
||||||
sig.return_annotation
|
|
||||||
]
|
|
||||||
|
|
||||||
return optional_converter
|
|
||||||
|
|
||||||
|
|
||||||
def default_if_none(default=NOTHING, factory=None):
|
|
||||||
"""
|
|
||||||
A converter that allows to replace ``None`` values by *default* or the
|
|
||||||
result of *factory*.
|
|
||||||
|
|
||||||
:param default: Value to be used if ``None`` is passed. Passing an instance
|
|
||||||
of `attrs.Factory` is supported, however the ``takes_self`` option
|
|
||||||
is *not*.
|
|
||||||
:param callable factory: A callable that takes no parameters whose result
|
|
||||||
is used if ``None`` is passed.
|
|
||||||
|
|
||||||
:raises TypeError: If **neither** *default* or *factory* is passed.
|
|
||||||
:raises TypeError: If **both** *default* and *factory* are passed.
|
|
||||||
:raises ValueError: If an instance of `attrs.Factory` is passed with
|
|
||||||
``takes_self=True``.
|
|
||||||
|
|
||||||
.. versionadded:: 18.2.0
|
|
||||||
"""
|
|
||||||
if default is NOTHING and factory is None:
|
|
||||||
raise TypeError("Must pass either `default` or `factory`.")
|
|
||||||
|
|
||||||
if default is not NOTHING and factory is not None:
|
|
||||||
raise TypeError(
|
|
||||||
"Must pass either `default` or `factory` but not both."
|
|
||||||
)
|
|
||||||
|
|
||||||
if factory is not None:
|
|
||||||
default = Factory(factory)
|
|
||||||
|
|
||||||
if isinstance(default, Factory):
|
|
||||||
if default.takes_self:
|
|
||||||
raise ValueError(
|
|
||||||
"`takes_self` is not supported by default_if_none."
|
|
||||||
)
|
|
||||||
|
|
||||||
def default_if_none_converter(val):
|
|
||||||
if val is not None:
|
|
||||||
return val
|
|
||||||
|
|
||||||
return default.factory()
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
def default_if_none_converter(val):
|
|
||||||
if val is not None:
|
|
||||||
return val
|
|
||||||
|
|
||||||
return default
|
|
||||||
|
|
||||||
return default_if_none_converter
|
|
||||||
|
|
||||||
|
|
||||||
def to_bool(val):
|
|
||||||
"""
|
|
||||||
Convert "boolean" strings (e.g., from env. vars.) to real booleans.
|
|
||||||
|
|
||||||
Values mapping to :code:`True`:
|
|
||||||
|
|
||||||
- :code:`True`
|
|
||||||
- :code:`"true"` / :code:`"t"`
|
|
||||||
- :code:`"yes"` / :code:`"y"`
|
|
||||||
- :code:`"on"`
|
|
||||||
- :code:`"1"`
|
|
||||||
- :code:`1`
|
|
||||||
|
|
||||||
Values mapping to :code:`False`:
|
|
||||||
|
|
||||||
- :code:`False`
|
|
||||||
- :code:`"false"` / :code:`"f"`
|
|
||||||
- :code:`"no"` / :code:`"n"`
|
|
||||||
- :code:`"off"`
|
|
||||||
- :code:`"0"`
|
|
||||||
- :code:`0`
|
|
||||||
|
|
||||||
:raises ValueError: for any other value.
|
|
||||||
|
|
||||||
.. versionadded:: 21.3.0
|
|
||||||
"""
|
|
||||||
if isinstance(val, str):
|
|
||||||
val = val.lower()
|
|
||||||
truthy = {True, "true", "t", "yes", "y", "on", "1", 1}
|
|
||||||
falsy = {False, "false", "f", "no", "n", "off", "0", 0}
|
|
||||||
try:
|
|
||||||
if val in truthy:
|
|
||||||
return True
|
|
||||||
if val in falsy:
|
|
||||||
return False
|
|
||||||
except TypeError:
|
|
||||||
# Raised when "val" is not hashable (e.g., lists)
|
|
||||||
pass
|
|
||||||
raise ValueError("Cannot convert value to bool: {}".format(val))
|
|
||||||
@@ -1,94 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
|
||||||
|
|
||||||
|
|
||||||
class FrozenError(AttributeError):
|
|
||||||
"""
|
|
||||||
A frozen/immutable instance or attribute have been attempted to be
|
|
||||||
modified.
|
|
||||||
|
|
||||||
It mirrors the behavior of ``namedtuples`` by using the same error message
|
|
||||||
and subclassing `AttributeError`.
|
|
||||||
|
|
||||||
.. versionadded:: 20.1.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
msg = "can't set attribute"
|
|
||||||
args = [msg]
|
|
||||||
|
|
||||||
|
|
||||||
class FrozenInstanceError(FrozenError):
|
|
||||||
"""
|
|
||||||
A frozen instance has been attempted to be modified.
|
|
||||||
|
|
||||||
.. versionadded:: 16.1.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class FrozenAttributeError(FrozenError):
|
|
||||||
"""
|
|
||||||
A frozen attribute has been attempted to be modified.
|
|
||||||
|
|
||||||
.. versionadded:: 20.1.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class AttrsAttributeNotFoundError(ValueError):
|
|
||||||
"""
|
|
||||||
An ``attrs`` function couldn't find an attribute that the user asked for.
|
|
||||||
|
|
||||||
.. versionadded:: 16.2.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class NotAnAttrsClassError(ValueError):
|
|
||||||
"""
|
|
||||||
A non-``attrs`` class has been passed into an ``attrs`` function.
|
|
||||||
|
|
||||||
.. versionadded:: 16.2.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultAlreadySetError(RuntimeError):
|
|
||||||
"""
|
|
||||||
A default has been set using ``attr.ib()`` and is attempted to be reset
|
|
||||||
using the decorator.
|
|
||||||
|
|
||||||
.. versionadded:: 17.1.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class UnannotatedAttributeError(RuntimeError):
|
|
||||||
"""
|
|
||||||
A class with ``auto_attribs=True`` has an ``attr.ib()`` without a type
|
|
||||||
annotation.
|
|
||||||
|
|
||||||
.. versionadded:: 17.3.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class PythonTooOldError(RuntimeError):
|
|
||||||
"""
|
|
||||||
It was attempted to use an ``attrs`` feature that requires a newer Python
|
|
||||||
version.
|
|
||||||
|
|
||||||
.. versionadded:: 18.2.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class NotCallableError(TypeError):
|
|
||||||
"""
|
|
||||||
A ``attr.ib()`` requiring a callable has been set with a value
|
|
||||||
that is not callable.
|
|
||||||
|
|
||||||
.. versionadded:: 19.2.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, msg, value):
|
|
||||||
super(TypeError, self).__init__(msg, value)
|
|
||||||
self.msg = msg
|
|
||||||
self.value = value
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return str(self.msg)
|
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
"""
|
|
||||||
Commonly useful filters for `attr.asdict`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
|
||||||
|
|
||||||
from ._compat import isclass
|
|
||||||
from ._make import Attribute
|
|
||||||
|
|
||||||
|
|
||||||
def _split_what(what):
|
|
||||||
"""
|
|
||||||
Returns a tuple of `frozenset`s of classes and attributes.
|
|
||||||
"""
|
|
||||||
return (
|
|
||||||
frozenset(cls for cls in what if isclass(cls)),
|
|
||||||
frozenset(cls for cls in what if isinstance(cls, Attribute)),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def include(*what):
|
|
||||||
"""
|
|
||||||
Include *what*.
|
|
||||||
|
|
||||||
:param what: What to include.
|
|
||||||
:type what: `list` of `type` or `attrs.Attribute`\\ s
|
|
||||||
|
|
||||||
:rtype: `callable`
|
|
||||||
"""
|
|
||||||
cls, attrs = _split_what(what)
|
|
||||||
|
|
||||||
def include_(attribute, value):
|
|
||||||
return value.__class__ in cls or attribute in attrs
|
|
||||||
|
|
||||||
return include_
|
|
||||||
|
|
||||||
|
|
||||||
def exclude(*what):
|
|
||||||
"""
|
|
||||||
Exclude *what*.
|
|
||||||
|
|
||||||
:param what: What to exclude.
|
|
||||||
:type what: `list` of classes or `attrs.Attribute`\\ s.
|
|
||||||
|
|
||||||
:rtype: `callable`
|
|
||||||
"""
|
|
||||||
cls, attrs = _split_what(what)
|
|
||||||
|
|
||||||
def exclude_(attribute, value):
|
|
||||||
return value.__class__ not in cls and attribute not in attrs
|
|
||||||
|
|
||||||
return exclude_
|
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
"""
|
|
||||||
Commonly used hooks for on_setattr.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
|
||||||
|
|
||||||
from . import _config
|
|
||||||
from .exceptions import FrozenAttributeError
|
|
||||||
|
|
||||||
|
|
||||||
def pipe(*setters):
|
|
||||||
"""
|
|
||||||
Run all *setters* and return the return value of the last one.
|
|
||||||
|
|
||||||
.. versionadded:: 20.1.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
def wrapped_pipe(instance, attrib, new_value):
|
|
||||||
rv = new_value
|
|
||||||
|
|
||||||
for setter in setters:
|
|
||||||
rv = setter(instance, attrib, rv)
|
|
||||||
|
|
||||||
return rv
|
|
||||||
|
|
||||||
return wrapped_pipe
|
|
||||||
|
|
||||||
|
|
||||||
def frozen(_, __, ___):
|
|
||||||
"""
|
|
||||||
Prevent an attribute to be modified.
|
|
||||||
|
|
||||||
.. versionadded:: 20.1.0
|
|
||||||
"""
|
|
||||||
raise FrozenAttributeError()
|
|
||||||
|
|
||||||
|
|
||||||
def validate(instance, attrib, new_value):
|
|
||||||
"""
|
|
||||||
Run *attrib*'s validator on *new_value* if it has one.
|
|
||||||
|
|
||||||
.. versionadded:: 20.1.0
|
|
||||||
"""
|
|
||||||
if _config._run_validators is False:
|
|
||||||
return new_value
|
|
||||||
|
|
||||||
v = attrib.validator
|
|
||||||
if not v:
|
|
||||||
return new_value
|
|
||||||
|
|
||||||
v(instance, attrib, new_value)
|
|
||||||
|
|
||||||
return new_value
|
|
||||||
|
|
||||||
|
|
||||||
def convert(instance, attrib, new_value):
|
|
||||||
"""
|
|
||||||
Run *attrib*'s converter -- if it has one -- on *new_value* and return the
|
|
||||||
result.
|
|
||||||
|
|
||||||
.. versionadded:: 20.1.0
|
|
||||||
"""
|
|
||||||
c = attrib.converter
|
|
||||||
if c:
|
|
||||||
return c(new_value)
|
|
||||||
|
|
||||||
return new_value
|
|
||||||
|
|
||||||
|
|
||||||
NO_OP = object()
|
|
||||||
"""
|
|
||||||
Sentinel for disabling class-wide *on_setattr* hooks for certain attributes.
|
|
||||||
|
|
||||||
Does not work in `pipe` or within lists.
|
|
||||||
|
|
||||||
.. versionadded:: 20.1.0
|
|
||||||
"""
|
|
||||||
@@ -1,561 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
"""
|
|
||||||
Commonly useful validators.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function
|
|
||||||
|
|
||||||
import operator
|
|
||||||
import re
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
from ._config import get_run_validators, set_run_validators
|
|
||||||
from ._make import _AndValidator, and_, attrib, attrs
|
|
||||||
from .exceptions import NotCallableError
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
Pattern = re.Pattern
|
|
||||||
except AttributeError: # Python <3.7 lacks a Pattern type.
|
|
||||||
Pattern = type(re.compile(""))
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"and_",
|
|
||||||
"deep_iterable",
|
|
||||||
"deep_mapping",
|
|
||||||
"disabled",
|
|
||||||
"ge",
|
|
||||||
"get_disabled",
|
|
||||||
"gt",
|
|
||||||
"in_",
|
|
||||||
"instance_of",
|
|
||||||
"is_callable",
|
|
||||||
"le",
|
|
||||||
"lt",
|
|
||||||
"matches_re",
|
|
||||||
"max_len",
|
|
||||||
"optional",
|
|
||||||
"provides",
|
|
||||||
"set_disabled",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def set_disabled(disabled):
|
|
||||||
"""
|
|
||||||
Globally disable or enable running validators.
|
|
||||||
|
|
||||||
By default, they are run.
|
|
||||||
|
|
||||||
:param disabled: If ``True``, disable running all validators.
|
|
||||||
:type disabled: bool
|
|
||||||
|
|
||||||
.. warning::
|
|
||||||
|
|
||||||
This function is not thread-safe!
|
|
||||||
|
|
||||||
.. versionadded:: 21.3.0
|
|
||||||
"""
|
|
||||||
set_run_validators(not disabled)
|
|
||||||
|
|
||||||
|
|
||||||
def get_disabled():
|
|
||||||
"""
|
|
||||||
Return a bool indicating whether validators are currently disabled or not.
|
|
||||||
|
|
||||||
:return: ``True`` if validators are currently disabled.
|
|
||||||
:rtype: bool
|
|
||||||
|
|
||||||
.. versionadded:: 21.3.0
|
|
||||||
"""
|
|
||||||
return not get_run_validators()
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def disabled():
|
|
||||||
"""
|
|
||||||
Context manager that disables running validators within its context.
|
|
||||||
|
|
||||||
.. warning::
|
|
||||||
|
|
||||||
This context manager is not thread-safe!
|
|
||||||
|
|
||||||
.. versionadded:: 21.3.0
|
|
||||||
"""
|
|
||||||
set_run_validators(False)
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
set_run_validators(True)
|
|
||||||
|
|
||||||
|
|
||||||
@attrs(repr=False, slots=True, hash=True)
|
|
||||||
class _InstanceOfValidator(object):
|
|
||||||
type = attrib()
|
|
||||||
|
|
||||||
def __call__(self, inst, attr, value):
|
|
||||||
"""
|
|
||||||
We use a callable class to be able to change the ``__repr__``.
|
|
||||||
"""
|
|
||||||
if not isinstance(value, self.type):
|
|
||||||
raise TypeError(
|
|
||||||
"'{name}' must be {type!r} (got {value!r} that is a "
|
|
||||||
"{actual!r}).".format(
|
|
||||||
name=attr.name,
|
|
||||||
type=self.type,
|
|
||||||
actual=value.__class__,
|
|
||||||
value=value,
|
|
||||||
),
|
|
||||||
attr,
|
|
||||||
self.type,
|
|
||||||
value,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<instance_of validator for type {type!r}>".format(
|
|
||||||
type=self.type
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def instance_of(type):
|
|
||||||
"""
|
|
||||||
A validator that raises a `TypeError` if the initializer is called
|
|
||||||
with a wrong type for this particular attribute (checks are performed using
|
|
||||||
`isinstance` therefore it's also valid to pass a tuple of types).
|
|
||||||
|
|
||||||
:param type: The type to check for.
|
|
||||||
:type type: type or tuple of types
|
|
||||||
|
|
||||||
:raises TypeError: With a human readable error message, the attribute
|
|
||||||
(of type `attrs.Attribute`), the expected type, and the value it
|
|
||||||
got.
|
|
||||||
"""
|
|
||||||
return _InstanceOfValidator(type)
|
|
||||||
|
|
||||||
|
|
||||||
@attrs(repr=False, frozen=True, slots=True)
|
|
||||||
class _MatchesReValidator(object):
|
|
||||||
pattern = attrib()
|
|
||||||
match_func = attrib()
|
|
||||||
|
|
||||||
def __call__(self, inst, attr, value):
|
|
||||||
"""
|
|
||||||
We use a callable class to be able to change the ``__repr__``.
|
|
||||||
"""
|
|
||||||
if not self.match_func(value):
|
|
||||||
raise ValueError(
|
|
||||||
"'{name}' must match regex {pattern!r}"
|
|
||||||
" ({value!r} doesn't)".format(
|
|
||||||
name=attr.name, pattern=self.pattern.pattern, value=value
|
|
||||||
),
|
|
||||||
attr,
|
|
||||||
self.pattern,
|
|
||||||
value,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<matches_re validator for pattern {pattern!r}>".format(
|
|
||||||
pattern=self.pattern
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def matches_re(regex, flags=0, func=None):
|
|
||||||
r"""
|
|
||||||
A validator that raises `ValueError` if the initializer is called
|
|
||||||
with a string that doesn't match *regex*.
|
|
||||||
|
|
||||||
:param regex: a regex string or precompiled pattern to match against
|
|
||||||
:param int flags: flags that will be passed to the underlying re function
|
|
||||||
(default 0)
|
|
||||||
:param callable func: which underlying `re` function to call (options
|
|
||||||
are `re.fullmatch`, `re.search`, `re.match`, default
|
|
||||||
is ``None`` which means either `re.fullmatch` or an emulation of
|
|
||||||
it on Python 2). For performance reasons, they won't be used directly
|
|
||||||
but on a pre-`re.compile`\ ed pattern.
|
|
||||||
|
|
||||||
.. versionadded:: 19.2.0
|
|
||||||
.. versionchanged:: 21.3.0 *regex* can be a pre-compiled pattern.
|
|
||||||
"""
|
|
||||||
fullmatch = getattr(re, "fullmatch", None)
|
|
||||||
valid_funcs = (fullmatch, None, re.search, re.match)
|
|
||||||
if func not in valid_funcs:
|
|
||||||
raise ValueError(
|
|
||||||
"'func' must be one of {}.".format(
|
|
||||||
", ".join(
|
|
||||||
sorted(
|
|
||||||
e and e.__name__ or "None" for e in set(valid_funcs)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(regex, Pattern):
|
|
||||||
if flags:
|
|
||||||
raise TypeError(
|
|
||||||
"'flags' can only be used with a string pattern; "
|
|
||||||
"pass flags to re.compile() instead"
|
|
||||||
)
|
|
||||||
pattern = regex
|
|
||||||
else:
|
|
||||||
pattern = re.compile(regex, flags)
|
|
||||||
|
|
||||||
if func is re.match:
|
|
||||||
match_func = pattern.match
|
|
||||||
elif func is re.search:
|
|
||||||
match_func = pattern.search
|
|
||||||
elif fullmatch:
|
|
||||||
match_func = pattern.fullmatch
|
|
||||||
else: # Python 2 fullmatch emulation (https://bugs.python.org/issue16203)
|
|
||||||
pattern = re.compile(
|
|
||||||
r"(?:{})\Z".format(pattern.pattern), pattern.flags
|
|
||||||
)
|
|
||||||
match_func = pattern.match
|
|
||||||
|
|
||||||
return _MatchesReValidator(pattern, match_func)
|
|
||||||
|
|
||||||
|
|
||||||
@attrs(repr=False, slots=True, hash=True)
|
|
||||||
class _ProvidesValidator(object):
|
|
||||||
interface = attrib()
|
|
||||||
|
|
||||||
def __call__(self, inst, attr, value):
|
|
||||||
"""
|
|
||||||
We use a callable class to be able to change the ``__repr__``.
|
|
||||||
"""
|
|
||||||
if not self.interface.providedBy(value):
|
|
||||||
raise TypeError(
|
|
||||||
"'{name}' must provide {interface!r} which {value!r} "
|
|
||||||
"doesn't.".format(
|
|
||||||
name=attr.name, interface=self.interface, value=value
|
|
||||||
),
|
|
||||||
attr,
|
|
||||||
self.interface,
|
|
||||||
value,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<provides validator for interface {interface!r}>".format(
|
|
||||||
interface=self.interface
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def provides(interface):
|
|
||||||
"""
|
|
||||||
A validator that raises a `TypeError` if the initializer is called
|
|
||||||
with an object that does not provide the requested *interface* (checks are
|
|
||||||
performed using ``interface.providedBy(value)`` (see `zope.interface
|
|
||||||
<https://zopeinterface.readthedocs.io/en/latest/>`_).
|
|
||||||
|
|
||||||
:param interface: The interface to check for.
|
|
||||||
:type interface: ``zope.interface.Interface``
|
|
||||||
|
|
||||||
:raises TypeError: With a human readable error message, the attribute
|
|
||||||
(of type `attrs.Attribute`), the expected interface, and the
|
|
||||||
value it got.
|
|
||||||
"""
|
|
||||||
return _ProvidesValidator(interface)
|
|
||||||
|
|
||||||
|
|
||||||
@attrs(repr=False, slots=True, hash=True)
|
|
||||||
class _OptionalValidator(object):
|
|
||||||
validator = attrib()
|
|
||||||
|
|
||||||
def __call__(self, inst, attr, value):
|
|
||||||
if value is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.validator(inst, attr, value)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<optional validator for {what} or None>".format(
|
|
||||||
what=repr(self.validator)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def optional(validator):
|
|
||||||
"""
|
|
||||||
A validator that makes an attribute optional. An optional attribute is one
|
|
||||||
which can be set to ``None`` in addition to satisfying the requirements of
|
|
||||||
the sub-validator.
|
|
||||||
|
|
||||||
:param validator: A validator (or a list of validators) that is used for
|
|
||||||
non-``None`` values.
|
|
||||||
:type validator: callable or `list` of callables.
|
|
||||||
|
|
||||||
.. versionadded:: 15.1.0
|
|
||||||
.. versionchanged:: 17.1.0 *validator* can be a list of validators.
|
|
||||||
"""
|
|
||||||
if isinstance(validator, list):
|
|
||||||
return _OptionalValidator(_AndValidator(validator))
|
|
||||||
return _OptionalValidator(validator)
|
|
||||||
|
|
||||||
|
|
||||||
@attrs(repr=False, slots=True, hash=True)
|
|
||||||
class _InValidator(object):
|
|
||||||
options = attrib()
|
|
||||||
|
|
||||||
def __call__(self, inst, attr, value):
|
|
||||||
try:
|
|
||||||
in_options = value in self.options
|
|
||||||
except TypeError: # e.g. `1 in "abc"`
|
|
||||||
in_options = False
|
|
||||||
|
|
||||||
if not in_options:
|
|
||||||
raise ValueError(
|
|
||||||
"'{name}' must be in {options!r} (got {value!r})".format(
|
|
||||||
name=attr.name, options=self.options, value=value
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<in_ validator with options {options!r}>".format(
|
|
||||||
options=self.options
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def in_(options):
|
|
||||||
"""
|
|
||||||
A validator that raises a `ValueError` if the initializer is called
|
|
||||||
with a value that does not belong in the options provided. The check is
|
|
||||||
performed using ``value in options``.
|
|
||||||
|
|
||||||
:param options: Allowed options.
|
|
||||||
:type options: list, tuple, `enum.Enum`, ...
|
|
||||||
|
|
||||||
:raises ValueError: With a human readable error message, the attribute (of
|
|
||||||
type `attrs.Attribute`), the expected options, and the value it
|
|
||||||
got.
|
|
||||||
|
|
||||||
.. versionadded:: 17.1.0
|
|
||||||
"""
|
|
||||||
return _InValidator(options)
|
|
||||||
|
|
||||||
|
|
||||||
@attrs(repr=False, slots=False, hash=True)
|
|
||||||
class _IsCallableValidator(object):
|
|
||||||
def __call__(self, inst, attr, value):
|
|
||||||
"""
|
|
||||||
We use a callable class to be able to change the ``__repr__``.
|
|
||||||
"""
|
|
||||||
if not callable(value):
|
|
||||||
message = (
|
|
||||||
"'{name}' must be callable "
|
|
||||||
"(got {value!r} that is a {actual!r})."
|
|
||||||
)
|
|
||||||
raise NotCallableError(
|
|
||||||
msg=message.format(
|
|
||||||
name=attr.name, value=value, actual=value.__class__
|
|
||||||
),
|
|
||||||
value=value,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<is_callable validator>"
|
|
||||||
|
|
||||||
|
|
||||||
def is_callable():
|
|
||||||
"""
|
|
||||||
A validator that raises a `attr.exceptions.NotCallableError` if the
|
|
||||||
initializer is called with a value for this particular attribute
|
|
||||||
that is not callable.
|
|
||||||
|
|
||||||
.. versionadded:: 19.1.0
|
|
||||||
|
|
||||||
:raises `attr.exceptions.NotCallableError`: With a human readable error
|
|
||||||
message containing the attribute (`attrs.Attribute`) name,
|
|
||||||
and the value it got.
|
|
||||||
"""
|
|
||||||
return _IsCallableValidator()
|
|
||||||
|
|
||||||
|
|
||||||
@attrs(repr=False, slots=True, hash=True)
|
|
||||||
class _DeepIterable(object):
|
|
||||||
member_validator = attrib(validator=is_callable())
|
|
||||||
iterable_validator = attrib(
|
|
||||||
default=None, validator=optional(is_callable())
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, inst, attr, value):
|
|
||||||
"""
|
|
||||||
We use a callable class to be able to change the ``__repr__``.
|
|
||||||
"""
|
|
||||||
if self.iterable_validator is not None:
|
|
||||||
self.iterable_validator(inst, attr, value)
|
|
||||||
|
|
||||||
for member in value:
|
|
||||||
self.member_validator(inst, attr, member)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
iterable_identifier = (
|
|
||||||
""
|
|
||||||
if self.iterable_validator is None
|
|
||||||
else " {iterable!r}".format(iterable=self.iterable_validator)
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
"<deep_iterable validator for{iterable_identifier}"
|
|
||||||
" iterables of {member!r}>"
|
|
||||||
).format(
|
|
||||||
iterable_identifier=iterable_identifier,
|
|
||||||
member=self.member_validator,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def deep_iterable(member_validator, iterable_validator=None):
|
|
||||||
"""
|
|
||||||
A validator that performs deep validation of an iterable.
|
|
||||||
|
|
||||||
:param member_validator: Validator to apply to iterable members
|
|
||||||
:param iterable_validator: Validator to apply to iterable itself
|
|
||||||
(optional)
|
|
||||||
|
|
||||||
.. versionadded:: 19.1.0
|
|
||||||
|
|
||||||
:raises TypeError: if any sub-validators fail
|
|
||||||
"""
|
|
||||||
return _DeepIterable(member_validator, iterable_validator)
|
|
||||||
|
|
||||||
|
|
||||||
@attrs(repr=False, slots=True, hash=True)
|
|
||||||
class _DeepMapping(object):
|
|
||||||
key_validator = attrib(validator=is_callable())
|
|
||||||
value_validator = attrib(validator=is_callable())
|
|
||||||
mapping_validator = attrib(default=None, validator=optional(is_callable()))
|
|
||||||
|
|
||||||
def __call__(self, inst, attr, value):
|
|
||||||
"""
|
|
||||||
We use a callable class to be able to change the ``__repr__``.
|
|
||||||
"""
|
|
||||||
if self.mapping_validator is not None:
|
|
||||||
self.mapping_validator(inst, attr, value)
|
|
||||||
|
|
||||||
for key in value:
|
|
||||||
self.key_validator(inst, attr, key)
|
|
||||||
self.value_validator(inst, attr, value[key])
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return (
|
|
||||||
"<deep_mapping validator for objects mapping {key!r} to {value!r}>"
|
|
||||||
).format(key=self.key_validator, value=self.value_validator)
|
|
||||||
|
|
||||||
|
|
||||||
def deep_mapping(key_validator, value_validator, mapping_validator=None):
|
|
||||||
"""
|
|
||||||
A validator that performs deep validation of a dictionary.
|
|
||||||
|
|
||||||
:param key_validator: Validator to apply to dictionary keys
|
|
||||||
:param value_validator: Validator to apply to dictionary values
|
|
||||||
:param mapping_validator: Validator to apply to top-level mapping
|
|
||||||
attribute (optional)
|
|
||||||
|
|
||||||
.. versionadded:: 19.1.0
|
|
||||||
|
|
||||||
:raises TypeError: if any sub-validators fail
|
|
||||||
"""
|
|
||||||
return _DeepMapping(key_validator, value_validator, mapping_validator)
|
|
||||||
|
|
||||||
|
|
||||||
@attrs(repr=False, frozen=True, slots=True)
|
|
||||||
class _NumberValidator(object):
|
|
||||||
bound = attrib()
|
|
||||||
compare_op = attrib()
|
|
||||||
compare_func = attrib()
|
|
||||||
|
|
||||||
def __call__(self, inst, attr, value):
|
|
||||||
"""
|
|
||||||
We use a callable class to be able to change the ``__repr__``.
|
|
||||||
"""
|
|
||||||
if not self.compare_func(value, self.bound):
|
|
||||||
raise ValueError(
|
|
||||||
"'{name}' must be {op} {bound}: {value}".format(
|
|
||||||
name=attr.name,
|
|
||||||
op=self.compare_op,
|
|
||||||
bound=self.bound,
|
|
||||||
value=value,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<Validator for x {op} {bound}>".format(
|
|
||||||
op=self.compare_op, bound=self.bound
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def lt(val):
|
|
||||||
"""
|
|
||||||
A validator that raises `ValueError` if the initializer is called
|
|
||||||
with a number larger or equal to *val*.
|
|
||||||
|
|
||||||
:param val: Exclusive upper bound for values
|
|
||||||
|
|
||||||
.. versionadded:: 21.3.0
|
|
||||||
"""
|
|
||||||
return _NumberValidator(val, "<", operator.lt)
|
|
||||||
|
|
||||||
|
|
||||||
def le(val):
|
|
||||||
"""
|
|
||||||
A validator that raises `ValueError` if the initializer is called
|
|
||||||
with a number greater than *val*.
|
|
||||||
|
|
||||||
:param val: Inclusive upper bound for values
|
|
||||||
|
|
||||||
.. versionadded:: 21.3.0
|
|
||||||
"""
|
|
||||||
return _NumberValidator(val, "<=", operator.le)
|
|
||||||
|
|
||||||
|
|
||||||
def ge(val):
|
|
||||||
"""
|
|
||||||
A validator that raises `ValueError` if the initializer is called
|
|
||||||
with a number smaller than *val*.
|
|
||||||
|
|
||||||
:param val: Inclusive lower bound for values
|
|
||||||
|
|
||||||
.. versionadded:: 21.3.0
|
|
||||||
"""
|
|
||||||
return _NumberValidator(val, ">=", operator.ge)
|
|
||||||
|
|
||||||
|
|
||||||
def gt(val):
|
|
||||||
"""
|
|
||||||
A validator that raises `ValueError` if the initializer is called
|
|
||||||
with a number smaller or equal to *val*.
|
|
||||||
|
|
||||||
:param val: Exclusive lower bound for values
|
|
||||||
|
|
||||||
.. versionadded:: 21.3.0
|
|
||||||
"""
|
|
||||||
return _NumberValidator(val, ">", operator.gt)
|
|
||||||
|
|
||||||
|
|
||||||
@attrs(repr=False, frozen=True, slots=True)
|
|
||||||
class _MaxLengthValidator(object):
|
|
||||||
max_length = attrib()
|
|
||||||
|
|
||||||
def __call__(self, inst, attr, value):
|
|
||||||
"""
|
|
||||||
We use a callable class to be able to change the ``__repr__``.
|
|
||||||
"""
|
|
||||||
if len(value) > self.max_length:
|
|
||||||
raise ValueError(
|
|
||||||
"Length of '{name}' must be <= {max}: {len}".format(
|
|
||||||
name=attr.name, max=self.max_length, len=len(value)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return "<max_len validator for {max}>".format(max=self.max_length)
|
|
||||||
|
|
||||||
|
|
||||||
def max_len(length):
|
|
||||||
"""
|
|
||||||
A validator that raises `ValueError` if the initializer is called
|
|
||||||
with a string or iterable that is longer than *length*.
|
|
||||||
|
|
||||||
:param int length: Maximum length of the string or iterable
|
|
||||||
|
|
||||||
.. versionadded:: 21.3.0
|
|
||||||
"""
|
|
||||||
return _MaxLengthValidator(length)
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
from attr import (
|
|
||||||
NOTHING,
|
|
||||||
Attribute,
|
|
||||||
Factory,
|
|
||||||
__author__,
|
|
||||||
__copyright__,
|
|
||||||
__description__,
|
|
||||||
__doc__,
|
|
||||||
__email__,
|
|
||||||
__license__,
|
|
||||||
__title__,
|
|
||||||
__url__,
|
|
||||||
__version__,
|
|
||||||
__version_info__,
|
|
||||||
assoc,
|
|
||||||
cmp_using,
|
|
||||||
define,
|
|
||||||
evolve,
|
|
||||||
field,
|
|
||||||
fields,
|
|
||||||
fields_dict,
|
|
||||||
frozen,
|
|
||||||
has,
|
|
||||||
make_class,
|
|
||||||
mutable,
|
|
||||||
resolve_types,
|
|
||||||
validate,
|
|
||||||
)
|
|
||||||
from attr._next_gen import asdict, astuple
|
|
||||||
|
|
||||||
from . import converters, exceptions, filters, setters, validators
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"__author__",
|
|
||||||
"__copyright__",
|
|
||||||
"__description__",
|
|
||||||
"__doc__",
|
|
||||||
"__email__",
|
|
||||||
"__license__",
|
|
||||||
"__title__",
|
|
||||||
"__url__",
|
|
||||||
"__version__",
|
|
||||||
"__version_info__",
|
|
||||||
"asdict",
|
|
||||||
"assoc",
|
|
||||||
"astuple",
|
|
||||||
"Attribute",
|
|
||||||
"cmp_using",
|
|
||||||
"converters",
|
|
||||||
"define",
|
|
||||||
"evolve",
|
|
||||||
"exceptions",
|
|
||||||
"Factory",
|
|
||||||
"field",
|
|
||||||
"fields_dict",
|
|
||||||
"fields",
|
|
||||||
"filters",
|
|
||||||
"frozen",
|
|
||||||
"has",
|
|
||||||
"make_class",
|
|
||||||
"mutable",
|
|
||||||
"NOTHING",
|
|
||||||
"resolve_types",
|
|
||||||
"setters",
|
|
||||||
"validate",
|
|
||||||
"validators",
|
|
||||||
]
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
from attr.converters import * # noqa
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
from attr.exceptions import * # noqa
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
from attr.filters import * # noqa
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
from attr.setters import * # noqa
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
from attr.validators import * # noqa
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
from .core import contents, where
|
|
||||||
|
|
||||||
__all__ = ["contents", "where"]
|
|
||||||
__version__ = "2022.05.18"
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
import argparse
|
|
||||||
|
|
||||||
from certifi import contents, where
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument("-c", "--contents", action="store_true")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.contents:
|
|
||||||
print(contents())
|
|
||||||
else:
|
|
||||||
print(where())
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
"""
|
|
||||||
certifi.py
|
|
||||||
~~~~~~~~~~
|
|
||||||
|
|
||||||
This module returns the installation location of cacert.pem or its contents.
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
import types
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
try:
|
|
||||||
from importlib.resources import path as get_path, read_text
|
|
||||||
|
|
||||||
_CACERT_CTX = None
|
|
||||||
_CACERT_PATH = None
|
|
||||||
|
|
||||||
def where() -> str:
|
|
||||||
# This is slightly terrible, but we want to delay extracting the file
|
|
||||||
# in cases where we're inside of a zipimport situation until someone
|
|
||||||
# actually calls where(), but we don't want to re-extract the file
|
|
||||||
# on every call of where(), so we'll do it once then store it in a
|
|
||||||
# global variable.
|
|
||||||
global _CACERT_CTX
|
|
||||||
global _CACERT_PATH
|
|
||||||
if _CACERT_PATH is None:
|
|
||||||
# This is slightly janky, the importlib.resources API wants you to
|
|
||||||
# manage the cleanup of this file, so it doesn't actually return a
|
|
||||||
# path, it returns a context manager that will give you the path
|
|
||||||
# when you enter it and will do any cleanup when you leave it. In
|
|
||||||
# the common case of not needing a temporary file, it will just
|
|
||||||
# return the file system location and the __exit__() is a no-op.
|
|
||||||
#
|
|
||||||
# We also have to hold onto the actual context manager, because
|
|
||||||
# it will do the cleanup whenever it gets garbage collected, so
|
|
||||||
# we will also store that at the global level as well.
|
|
||||||
_CACERT_CTX = get_path("certifi", "cacert.pem")
|
|
||||||
_CACERT_PATH = str(_CACERT_CTX.__enter__())
|
|
||||||
|
|
||||||
return _CACERT_PATH
|
|
||||||
|
|
||||||
|
|
||||||
except ImportError:
|
|
||||||
Package = Union[types.ModuleType, str]
|
|
||||||
Resource = Union[str, "os.PathLike"]
|
|
||||||
|
|
||||||
# This fallback will work for Python versions prior to 3.7 that lack the
|
|
||||||
# importlib.resources module but relies on the existing `where` function
|
|
||||||
# so won't address issues with environments like PyOxidizer that don't set
|
|
||||||
# __file__ on modules.
|
|
||||||
def read_text(
|
|
||||||
package: Package,
|
|
||||||
resource: Resource,
|
|
||||||
encoding: str = 'utf-8',
|
|
||||||
errors: str = 'strict'
|
|
||||||
) -> str:
|
|
||||||
with open(where(), "r", encoding=encoding) as data:
|
|
||||||
return data.read()
|
|
||||||
|
|
||||||
# If we don't have importlib.resources, then we will just do the old logic
|
|
||||||
# of assuming we're on the filesystem and munge the path directly.
|
|
||||||
def where() -> str:
|
|
||||||
f = os.path.dirname(__file__)
|
|
||||||
|
|
||||||
return os.path.join(f, "cacert.pem")
|
|
||||||
|
|
||||||
|
|
||||||
def contents() -> str:
|
|
||||||
return read_text("certifi", "cacert.pem", encoding="ascii")
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
__all__ = ['FFI', 'VerificationError', 'VerificationMissing', 'CDefError',
|
|
||||||
'FFIError']
|
|
||||||
|
|
||||||
from .api import FFI
|
|
||||||
from .error import CDefError, FFIError, VerificationError, VerificationMissing
|
|
||||||
from .error import PkgConfigError
|
|
||||||
|
|
||||||
__version__ = "1.15.0"
|
|
||||||
__version_info__ = (1, 15, 0)
|
|
||||||
|
|
||||||
# The verifier module file names are based on the CRC32 of a string that
|
|
||||||
# contains the following version number. It may be older than __version__
|
|
||||||
# if nothing is clearly incompatible.
|
|
||||||
__version_verifier_modules__ = "0.8.6"
|
|
||||||
@@ -1,965 +0,0 @@
|
|||||||
import sys, types
|
|
||||||
from .lock import allocate_lock
|
|
||||||
from .error import CDefError
|
|
||||||
from . import model
|
|
||||||
|
|
||||||
try:
|
|
||||||
callable
|
|
||||||
except NameError:
|
|
||||||
# Python 3.1
|
|
||||||
from collections import Callable
|
|
||||||
callable = lambda x: isinstance(x, Callable)
|
|
||||||
|
|
||||||
try:
|
|
||||||
basestring
|
|
||||||
except NameError:
|
|
||||||
# Python 3.x
|
|
||||||
basestring = str
|
|
||||||
|
|
||||||
_unspecified = object()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FFI(object):
|
|
||||||
r'''
|
|
||||||
The main top-level class that you instantiate once, or once per module.
|
|
||||||
|
|
||||||
Example usage:
|
|
||||||
|
|
||||||
ffi = FFI()
|
|
||||||
ffi.cdef("""
|
|
||||||
int printf(const char *, ...);
|
|
||||||
""")
|
|
||||||
|
|
||||||
C = ffi.dlopen(None) # standard library
|
|
||||||
-or-
|
|
||||||
C = ffi.verify() # use a C compiler: verify the decl above is right
|
|
||||||
|
|
||||||
C.printf("hello, %s!\n", ffi.new("char[]", "world"))
|
|
||||||
'''
|
|
||||||
|
|
||||||
def __init__(self, backend=None):
|
|
||||||
"""Create an FFI instance. The 'backend' argument is used to
|
|
||||||
select a non-default backend, mostly for tests.
|
|
||||||
"""
|
|
||||||
if backend is None:
|
|
||||||
# You need PyPy (>= 2.0 beta), or a CPython (>= 2.6) with
|
|
||||||
# _cffi_backend.so compiled.
|
|
||||||
import _cffi_backend as backend
|
|
||||||
from . import __version__
|
|
||||||
if backend.__version__ != __version__:
|
|
||||||
# bad version! Try to be as explicit as possible.
|
|
||||||
if hasattr(backend, '__file__'):
|
|
||||||
# CPython
|
|
||||||
raise Exception("Version mismatch: this is the 'cffi' package version %s, located in %r. When we import the top-level '_cffi_backend' extension module, we get version %s, located in %r. The two versions should be equal; check your installation." % (
|
|
||||||
__version__, __file__,
|
|
||||||
backend.__version__, backend.__file__))
|
|
||||||
else:
|
|
||||||
# PyPy
|
|
||||||
raise Exception("Version mismatch: this is the 'cffi' package version %s, located in %r. This interpreter comes with a built-in '_cffi_backend' module, which is version %s. The two versions should be equal; check your installation." % (
|
|
||||||
__version__, __file__, backend.__version__))
|
|
||||||
# (If you insist you can also try to pass the option
|
|
||||||
# 'backend=backend_ctypes.CTypesBackend()', but don't
|
|
||||||
# rely on it! It's probably not going to work well.)
|
|
||||||
|
|
||||||
from . import cparser
|
|
||||||
self._backend = backend
|
|
||||||
self._lock = allocate_lock()
|
|
||||||
self._parser = cparser.Parser()
|
|
||||||
self._cached_btypes = {}
|
|
||||||
self._parsed_types = types.ModuleType('parsed_types').__dict__
|
|
||||||
self._new_types = types.ModuleType('new_types').__dict__
|
|
||||||
self._function_caches = []
|
|
||||||
self._libraries = []
|
|
||||||
self._cdefsources = []
|
|
||||||
self._included_ffis = []
|
|
||||||
self._windows_unicode = None
|
|
||||||
self._init_once_cache = {}
|
|
||||||
self._cdef_version = None
|
|
||||||
self._embedding = None
|
|
||||||
self._typecache = model.get_typecache(backend)
|
|
||||||
if hasattr(backend, 'set_ffi'):
|
|
||||||
backend.set_ffi(self)
|
|
||||||
for name in list(backend.__dict__):
|
|
||||||
if name.startswith('RTLD_'):
|
|
||||||
setattr(self, name, getattr(backend, name))
|
|
||||||
#
|
|
||||||
with self._lock:
|
|
||||||
self.BVoidP = self._get_cached_btype(model.voidp_type)
|
|
||||||
self.BCharA = self._get_cached_btype(model.char_array_type)
|
|
||||||
if isinstance(backend, types.ModuleType):
|
|
||||||
# _cffi_backend: attach these constants to the class
|
|
||||||
if not hasattr(FFI, 'NULL'):
|
|
||||||
FFI.NULL = self.cast(self.BVoidP, 0)
|
|
||||||
FFI.CData, FFI.CType = backend._get_types()
|
|
||||||
else:
|
|
||||||
# ctypes backend: attach these constants to the instance
|
|
||||||
self.NULL = self.cast(self.BVoidP, 0)
|
|
||||||
self.CData, self.CType = backend._get_types()
|
|
||||||
self.buffer = backend.buffer
|
|
||||||
|
|
||||||
def cdef(self, csource, override=False, packed=False, pack=None):
|
|
||||||
"""Parse the given C source. This registers all declared functions,
|
|
||||||
types, and global variables. The functions and global variables can
|
|
||||||
then be accessed via either 'ffi.dlopen()' or 'ffi.verify()'.
|
|
||||||
The types can be used in 'ffi.new()' and other functions.
|
|
||||||
If 'packed' is specified as True, all structs declared inside this
|
|
||||||
cdef are packed, i.e. laid out without any field alignment at all.
|
|
||||||
Alternatively, 'pack' can be a small integer, and requests for
|
|
||||||
alignment greater than that are ignored (pack=1 is equivalent to
|
|
||||||
packed=True).
|
|
||||||
"""
|
|
||||||
self._cdef(csource, override=override, packed=packed, pack=pack)
|
|
||||||
|
|
||||||
def embedding_api(self, csource, packed=False, pack=None):
|
|
||||||
self._cdef(csource, packed=packed, pack=pack, dllexport=True)
|
|
||||||
if self._embedding is None:
|
|
||||||
self._embedding = ''
|
|
||||||
|
|
||||||
def _cdef(self, csource, override=False, **options):
|
|
||||||
if not isinstance(csource, str): # unicode, on Python 2
|
|
||||||
if not isinstance(csource, basestring):
|
|
||||||
raise TypeError("cdef() argument must be a string")
|
|
||||||
csource = csource.encode('ascii')
|
|
||||||
with self._lock:
|
|
||||||
self._cdef_version = object()
|
|
||||||
self._parser.parse(csource, override=override, **options)
|
|
||||||
self._cdefsources.append(csource)
|
|
||||||
if override:
|
|
||||||
for cache in self._function_caches:
|
|
||||||
cache.clear()
|
|
||||||
finishlist = self._parser._recomplete
|
|
||||||
if finishlist:
|
|
||||||
self._parser._recomplete = []
|
|
||||||
for tp in finishlist:
|
|
||||||
tp.finish_backend_type(self, finishlist)
|
|
||||||
|
|
||||||
def dlopen(self, name, flags=0):
|
|
||||||
"""Load and return a dynamic library identified by 'name'.
|
|
||||||
The standard C library can be loaded by passing None.
|
|
||||||
Note that functions and types declared by 'ffi.cdef()' are not
|
|
||||||
linked to a particular library, just like C headers; in the
|
|
||||||
library we only look for the actual (untyped) symbols.
|
|
||||||
"""
|
|
||||||
if not (isinstance(name, basestring) or
|
|
||||||
name is None or
|
|
||||||
isinstance(name, self.CData)):
|
|
||||||
raise TypeError("dlopen(name): name must be a file name, None, "
|
|
||||||
"or an already-opened 'void *' handle")
|
|
||||||
with self._lock:
|
|
||||||
lib, function_cache = _make_ffi_library(self, name, flags)
|
|
||||||
self._function_caches.append(function_cache)
|
|
||||||
self._libraries.append(lib)
|
|
||||||
return lib
|
|
||||||
|
|
||||||
def dlclose(self, lib):
|
|
||||||
"""Close a library obtained with ffi.dlopen(). After this call,
|
|
||||||
access to functions or variables from the library will fail
|
|
||||||
(possibly with a segmentation fault).
|
|
||||||
"""
|
|
||||||
type(lib).__cffi_close__(lib)
|
|
||||||
|
|
||||||
def _typeof_locked(self, cdecl):
|
|
||||||
# call me with the lock!
|
|
||||||
key = cdecl
|
|
||||||
if key in self._parsed_types:
|
|
||||||
return self._parsed_types[key]
|
|
||||||
#
|
|
||||||
if not isinstance(cdecl, str): # unicode, on Python 2
|
|
||||||
cdecl = cdecl.encode('ascii')
|
|
||||||
#
|
|
||||||
type = self._parser.parse_type(cdecl)
|
|
||||||
really_a_function_type = type.is_raw_function
|
|
||||||
if really_a_function_type:
|
|
||||||
type = type.as_function_pointer()
|
|
||||||
btype = self._get_cached_btype(type)
|
|
||||||
result = btype, really_a_function_type
|
|
||||||
self._parsed_types[key] = result
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _typeof(self, cdecl, consider_function_as_funcptr=False):
|
|
||||||
# string -> ctype object
|
|
||||||
try:
|
|
||||||
result = self._parsed_types[cdecl]
|
|
||||||
except KeyError:
|
|
||||||
with self._lock:
|
|
||||||
result = self._typeof_locked(cdecl)
|
|
||||||
#
|
|
||||||
btype, really_a_function_type = result
|
|
||||||
if really_a_function_type and not consider_function_as_funcptr:
|
|
||||||
raise CDefError("the type %r is a function type, not a "
|
|
||||||
"pointer-to-function type" % (cdecl,))
|
|
||||||
return btype
|
|
||||||
|
|
||||||
def typeof(self, cdecl):
|
|
||||||
"""Parse the C type given as a string and return the
|
|
||||||
corresponding <ctype> object.
|
|
||||||
It can also be used on 'cdata' instance to get its C type.
|
|
||||||
"""
|
|
||||||
if isinstance(cdecl, basestring):
|
|
||||||
return self._typeof(cdecl)
|
|
||||||
if isinstance(cdecl, self.CData):
|
|
||||||
return self._backend.typeof(cdecl)
|
|
||||||
if isinstance(cdecl, types.BuiltinFunctionType):
|
|
||||||
res = _builtin_function_type(cdecl)
|
|
||||||
if res is not None:
|
|
||||||
return res
|
|
||||||
if (isinstance(cdecl, types.FunctionType)
|
|
||||||
and hasattr(cdecl, '_cffi_base_type')):
|
|
||||||
with self._lock:
|
|
||||||
return self._get_cached_btype(cdecl._cffi_base_type)
|
|
||||||
raise TypeError(type(cdecl))
|
|
||||||
|
|
||||||
def sizeof(self, cdecl):
|
|
||||||
"""Return the size in bytes of the argument. It can be a
|
|
||||||
string naming a C type, or a 'cdata' instance.
|
|
||||||
"""
|
|
||||||
if isinstance(cdecl, basestring):
|
|
||||||
BType = self._typeof(cdecl)
|
|
||||||
return self._backend.sizeof(BType)
|
|
||||||
else:
|
|
||||||
return self._backend.sizeof(cdecl)
|
|
||||||
|
|
||||||
def alignof(self, cdecl):
|
|
||||||
"""Return the natural alignment size in bytes of the C type
|
|
||||||
given as a string.
|
|
||||||
"""
|
|
||||||
if isinstance(cdecl, basestring):
|
|
||||||
cdecl = self._typeof(cdecl)
|
|
||||||
return self._backend.alignof(cdecl)
|
|
||||||
|
|
||||||
def offsetof(self, cdecl, *fields_or_indexes):
|
|
||||||
"""Return the offset of the named field inside the given
|
|
||||||
structure or array, which must be given as a C type name.
|
|
||||||
You can give several field names in case of nested structures.
|
|
||||||
You can also give numeric values which correspond to array
|
|
||||||
items, in case of an array type.
|
|
||||||
"""
|
|
||||||
if isinstance(cdecl, basestring):
|
|
||||||
cdecl = self._typeof(cdecl)
|
|
||||||
return self._typeoffsetof(cdecl, *fields_or_indexes)[1]
|
|
||||||
|
|
||||||
def new(self, cdecl, init=None):
|
|
||||||
"""Allocate an instance according to the specified C type and
|
|
||||||
return a pointer to it. The specified C type must be either a
|
|
||||||
pointer or an array: ``new('X *')`` allocates an X and returns
|
|
||||||
a pointer to it, whereas ``new('X[n]')`` allocates an array of
|
|
||||||
n X'es and returns an array referencing it (which works
|
|
||||||
mostly like a pointer, like in C). You can also use
|
|
||||||
``new('X[]', n)`` to allocate an array of a non-constant
|
|
||||||
length n.
|
|
||||||
|
|
||||||
The memory is initialized following the rules of declaring a
|
|
||||||
global variable in C: by default it is zero-initialized, but
|
|
||||||
an explicit initializer can be given which can be used to
|
|
||||||
fill all or part of the memory.
|
|
||||||
|
|
||||||
When the returned <cdata> object goes out of scope, the memory
|
|
||||||
is freed. In other words the returned <cdata> object has
|
|
||||||
ownership of the value of type 'cdecl' that it points to. This
|
|
||||||
means that the raw data can be used as long as this object is
|
|
||||||
kept alive, but must not be used for a longer time. Be careful
|
|
||||||
about that when copying the pointer to the memory somewhere
|
|
||||||
else, e.g. into another structure.
|
|
||||||
"""
|
|
||||||
if isinstance(cdecl, basestring):
|
|
||||||
cdecl = self._typeof(cdecl)
|
|
||||||
return self._backend.newp(cdecl, init)
|
|
||||||
|
|
||||||
def new_allocator(self, alloc=None, free=None,
|
|
||||||
should_clear_after_alloc=True):
|
|
||||||
"""Return a new allocator, i.e. a function that behaves like ffi.new()
|
|
||||||
but uses the provided low-level 'alloc' and 'free' functions.
|
|
||||||
|
|
||||||
'alloc' is called with the size as argument. If it returns NULL, a
|
|
||||||
MemoryError is raised. 'free' is called with the result of 'alloc'
|
|
||||||
as argument. Both can be either Python function or directly C
|
|
||||||
functions. If 'free' is None, then no free function is called.
|
|
||||||
If both 'alloc' and 'free' are None, the default is used.
|
|
||||||
|
|
||||||
If 'should_clear_after_alloc' is set to False, then the memory
|
|
||||||
returned by 'alloc' is assumed to be already cleared (or you are
|
|
||||||
fine with garbage); otherwise CFFI will clear it.
|
|
||||||
"""
|
|
||||||
compiled_ffi = self._backend.FFI()
|
|
||||||
allocator = compiled_ffi.new_allocator(alloc, free,
|
|
||||||
should_clear_after_alloc)
|
|
||||||
def allocate(cdecl, init=None):
|
|
||||||
if isinstance(cdecl, basestring):
|
|
||||||
cdecl = self._typeof(cdecl)
|
|
||||||
return allocator(cdecl, init)
|
|
||||||
return allocate
|
|
||||||
|
|
||||||
def cast(self, cdecl, source):
|
|
||||||
"""Similar to a C cast: returns an instance of the named C
|
|
||||||
type initialized with the given 'source'. The source is
|
|
||||||
casted between integers or pointers of any type.
|
|
||||||
"""
|
|
||||||
if isinstance(cdecl, basestring):
|
|
||||||
cdecl = self._typeof(cdecl)
|
|
||||||
return self._backend.cast(cdecl, source)
|
|
||||||
|
|
||||||
def string(self, cdata, maxlen=-1):
|
|
||||||
"""Return a Python string (or unicode string) from the 'cdata'.
|
|
||||||
If 'cdata' is a pointer or array of characters or bytes, returns
|
|
||||||
the null-terminated string. The returned string extends until
|
|
||||||
the first null character, or at most 'maxlen' characters. If
|
|
||||||
'cdata' is an array then 'maxlen' defaults to its length.
|
|
||||||
|
|
||||||
If 'cdata' is a pointer or array of wchar_t, returns a unicode
|
|
||||||
string following the same rules.
|
|
||||||
|
|
||||||
If 'cdata' is a single character or byte or a wchar_t, returns
|
|
||||||
it as a string or unicode string.
|
|
||||||
|
|
||||||
If 'cdata' is an enum, returns the value of the enumerator as a
|
|
||||||
string, or 'NUMBER' if the value is out of range.
|
|
||||||
"""
|
|
||||||
return self._backend.string(cdata, maxlen)
|
|
||||||
|
|
||||||
def unpack(self, cdata, length):
|
|
||||||
"""Unpack an array of C data of the given length,
|
|
||||||
returning a Python string/unicode/list.
|
|
||||||
|
|
||||||
If 'cdata' is a pointer to 'char', returns a byte string.
|
|
||||||
It does not stop at the first null. This is equivalent to:
|
|
||||||
ffi.buffer(cdata, length)[:]
|
|
||||||
|
|
||||||
If 'cdata' is a pointer to 'wchar_t', returns a unicode string.
|
|
||||||
'length' is measured in wchar_t's; it is not the size in bytes.
|
|
||||||
|
|
||||||
If 'cdata' is a pointer to anything else, returns a list of
|
|
||||||
'length' items. This is a faster equivalent to:
|
|
||||||
[cdata[i] for i in range(length)]
|
|
||||||
"""
|
|
||||||
return self._backend.unpack(cdata, length)
|
|
||||||
|
|
||||||
#def buffer(self, cdata, size=-1):
|
|
||||||
# """Return a read-write buffer object that references the raw C data
|
|
||||||
# pointed to by the given 'cdata'. The 'cdata' must be a pointer or
|
|
||||||
# an array. Can be passed to functions expecting a buffer, or directly
|
|
||||||
# manipulated with:
|
|
||||||
#
|
|
||||||
# buf[:] get a copy of it in a regular string, or
|
|
||||||
# buf[idx] as a single character
|
|
||||||
# buf[:] = ...
|
|
||||||
# buf[idx] = ... change the content
|
|
||||||
# """
|
|
||||||
# note that 'buffer' is a type, set on this instance by __init__
|
|
||||||
|
|
||||||
def from_buffer(self, cdecl, python_buffer=_unspecified,
|
|
||||||
require_writable=False):
|
|
||||||
"""Return a cdata of the given type pointing to the data of the
|
|
||||||
given Python object, which must support the buffer interface.
|
|
||||||
Note that this is not meant to be used on the built-in types
|
|
||||||
str or unicode (you can build 'char[]' arrays explicitly)
|
|
||||||
but only on objects containing large quantities of raw data
|
|
||||||
in some other format, like 'array.array' or numpy arrays.
|
|
||||||
|
|
||||||
The first argument is optional and default to 'char[]'.
|
|
||||||
"""
|
|
||||||
if python_buffer is _unspecified:
|
|
||||||
cdecl, python_buffer = self.BCharA, cdecl
|
|
||||||
elif isinstance(cdecl, basestring):
|
|
||||||
cdecl = self._typeof(cdecl)
|
|
||||||
return self._backend.from_buffer(cdecl, python_buffer,
|
|
||||||
require_writable)
|
|
||||||
|
|
||||||
def memmove(self, dest, src, n):
|
|
||||||
"""ffi.memmove(dest, src, n) copies n bytes of memory from src to dest.
|
|
||||||
|
|
||||||
Like the C function memmove(), the memory areas may overlap;
|
|
||||||
apart from that it behaves like the C function memcpy().
|
|
||||||
|
|
||||||
'src' can be any cdata ptr or array, or any Python buffer object.
|
|
||||||
'dest' can be any cdata ptr or array, or a writable Python buffer
|
|
||||||
object. The size to copy, 'n', is always measured in bytes.
|
|
||||||
|
|
||||||
Unlike other methods, this one supports all Python buffer including
|
|
||||||
byte strings and bytearrays---but it still does not support
|
|
||||||
non-contiguous buffers.
|
|
||||||
"""
|
|
||||||
return self._backend.memmove(dest, src, n)
|
|
||||||
|
|
||||||
def callback(self, cdecl, python_callable=None, error=None, onerror=None):
|
|
||||||
"""Return a callback object or a decorator making such a
|
|
||||||
callback object. 'cdecl' must name a C function pointer type.
|
|
||||||
The callback invokes the specified 'python_callable' (which may
|
|
||||||
be provided either directly or via a decorator). Important: the
|
|
||||||
callback object must be manually kept alive for as long as the
|
|
||||||
callback may be invoked from the C level.
|
|
||||||
"""
|
|
||||||
def callback_decorator_wrap(python_callable):
|
|
||||||
if not callable(python_callable):
|
|
||||||
raise TypeError("the 'python_callable' argument "
|
|
||||||
"is not callable")
|
|
||||||
return self._backend.callback(cdecl, python_callable,
|
|
||||||
error, onerror)
|
|
||||||
if isinstance(cdecl, basestring):
|
|
||||||
cdecl = self._typeof(cdecl, consider_function_as_funcptr=True)
|
|
||||||
if python_callable is None:
|
|
||||||
return callback_decorator_wrap # decorator mode
|
|
||||||
else:
|
|
||||||
return callback_decorator_wrap(python_callable) # direct mode
|
|
||||||
|
|
||||||
def getctype(self, cdecl, replace_with=''):
|
|
||||||
"""Return a string giving the C type 'cdecl', which may be itself
|
|
||||||
a string or a <ctype> object. If 'replace_with' is given, it gives
|
|
||||||
extra text to append (or insert for more complicated C types), like
|
|
||||||
a variable name, or '*' to get actually the C type 'pointer-to-cdecl'.
|
|
||||||
"""
|
|
||||||
if isinstance(cdecl, basestring):
|
|
||||||
cdecl = self._typeof(cdecl)
|
|
||||||
replace_with = replace_with.strip()
|
|
||||||
if (replace_with.startswith('*')
|
|
||||||
and '&[' in self._backend.getcname(cdecl, '&')):
|
|
||||||
replace_with = '(%s)' % replace_with
|
|
||||||
elif replace_with and not replace_with[0] in '[(':
|
|
||||||
replace_with = ' ' + replace_with
|
|
||||||
return self._backend.getcname(cdecl, replace_with)
|
|
||||||
|
|
||||||
def gc(self, cdata, destructor, size=0):
|
|
||||||
"""Return a new cdata object that points to the same
|
|
||||||
data. Later, when this new cdata object is garbage-collected,
|
|
||||||
'destructor(old_cdata_object)' will be called.
|
|
||||||
|
|
||||||
The optional 'size' gives an estimate of the size, used to
|
|
||||||
trigger the garbage collection more eagerly. So far only used
|
|
||||||
on PyPy. It tells the GC that the returned object keeps alive
|
|
||||||
roughly 'size' bytes of external memory.
|
|
||||||
"""
|
|
||||||
return self._backend.gcp(cdata, destructor, size)
|
|
||||||
|
|
||||||
def _get_cached_btype(self, type):
|
|
||||||
assert self._lock.acquire(False) is False
|
|
||||||
# call me with the lock!
|
|
||||||
try:
|
|
||||||
BType = self._cached_btypes[type]
|
|
||||||
except KeyError:
|
|
||||||
finishlist = []
|
|
||||||
BType = type.get_cached_btype(self, finishlist)
|
|
||||||
for type in finishlist:
|
|
||||||
type.finish_backend_type(self, finishlist)
|
|
||||||
return BType
|
|
||||||
|
|
||||||
def verify(self, source='', tmpdir=None, **kwargs):
|
|
||||||
"""Verify that the current ffi signatures compile on this
|
|
||||||
machine, and return a dynamic library object. The dynamic
|
|
||||||
library can be used to call functions and access global
|
|
||||||
variables declared in this 'ffi'. The library is compiled
|
|
||||||
by the C compiler: it gives you C-level API compatibility
|
|
||||||
(including calling macros). This is unlike 'ffi.dlopen()',
|
|
||||||
which requires binary compatibility in the signatures.
|
|
||||||
"""
|
|
||||||
from .verifier import Verifier, _caller_dir_pycache
|
|
||||||
#
|
|
||||||
# If set_unicode(True) was called, insert the UNICODE and
|
|
||||||
# _UNICODE macro declarations
|
|
||||||
if self._windows_unicode:
|
|
||||||
self._apply_windows_unicode(kwargs)
|
|
||||||
#
|
|
||||||
# Set the tmpdir here, and not in Verifier.__init__: it picks
|
|
||||||
# up the caller's directory, which we want to be the caller of
|
|
||||||
# ffi.verify(), as opposed to the caller of Veritier().
|
|
||||||
tmpdir = tmpdir or _caller_dir_pycache()
|
|
||||||
#
|
|
||||||
# Make a Verifier() and use it to load the library.
|
|
||||||
self.verifier = Verifier(self, source, tmpdir, **kwargs)
|
|
||||||
lib = self.verifier.load_library()
|
|
||||||
#
|
|
||||||
# Save the loaded library for keep-alive purposes, even
|
|
||||||
# if the caller doesn't keep it alive itself (it should).
|
|
||||||
self._libraries.append(lib)
|
|
||||||
return lib
|
|
||||||
|
|
||||||
def _get_errno(self):
|
|
||||||
return self._backend.get_errno()
|
|
||||||
def _set_errno(self, errno):
|
|
||||||
self._backend.set_errno(errno)
|
|
||||||
errno = property(_get_errno, _set_errno, None,
|
|
||||||
"the value of 'errno' from/to the C calls")
|
|
||||||
|
|
||||||
def getwinerror(self, code=-1):
|
|
||||||
return self._backend.getwinerror(code)
|
|
||||||
|
|
||||||
def _pointer_to(self, ctype):
|
|
||||||
with self._lock:
|
|
||||||
return model.pointer_cache(self, ctype)
|
|
||||||
|
|
||||||
def addressof(self, cdata, *fields_or_indexes):
|
|
||||||
"""Return the address of a <cdata 'struct-or-union'>.
|
|
||||||
If 'fields_or_indexes' are given, returns the address of that
|
|
||||||
field or array item in the structure or array, recursively in
|
|
||||||
case of nested structures.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
ctype = self._backend.typeof(cdata)
|
|
||||||
except TypeError:
|
|
||||||
if '__addressof__' in type(cdata).__dict__:
|
|
||||||
return type(cdata).__addressof__(cdata, *fields_or_indexes)
|
|
||||||
raise
|
|
||||||
if fields_or_indexes:
|
|
||||||
ctype, offset = self._typeoffsetof(ctype, *fields_or_indexes)
|
|
||||||
else:
|
|
||||||
if ctype.kind == "pointer":
|
|
||||||
raise TypeError("addressof(pointer)")
|
|
||||||
offset = 0
|
|
||||||
ctypeptr = self._pointer_to(ctype)
|
|
||||||
return self._backend.rawaddressof(ctypeptr, cdata, offset)
|
|
||||||
|
|
||||||
def _typeoffsetof(self, ctype, field_or_index, *fields_or_indexes):
|
|
||||||
ctype, offset = self._backend.typeoffsetof(ctype, field_or_index)
|
|
||||||
for field1 in fields_or_indexes:
|
|
||||||
ctype, offset1 = self._backend.typeoffsetof(ctype, field1, 1)
|
|
||||||
offset += offset1
|
|
||||||
return ctype, offset
|
|
||||||
|
|
||||||
def include(self, ffi_to_include):
|
|
||||||
"""Includes the typedefs, structs, unions and enums defined
|
|
||||||
in another FFI instance. Usage is similar to a #include in C,
|
|
||||||
where a part of the program might include types defined in
|
|
||||||
another part for its own usage. Note that the include()
|
|
||||||
method has no effect on functions, constants and global
|
|
||||||
variables, which must anyway be accessed directly from the
|
|
||||||
lib object returned by the original FFI instance.
|
|
||||||
"""
|
|
||||||
if not isinstance(ffi_to_include, FFI):
|
|
||||||
raise TypeError("ffi.include() expects an argument that is also of"
|
|
||||||
" type cffi.FFI, not %r" % (
|
|
||||||
type(ffi_to_include).__name__,))
|
|
||||||
if ffi_to_include is self:
|
|
||||||
raise ValueError("self.include(self)")
|
|
||||||
with ffi_to_include._lock:
|
|
||||||
with self._lock:
|
|
||||||
self._parser.include(ffi_to_include._parser)
|
|
||||||
self._cdefsources.append('[')
|
|
||||||
self._cdefsources.extend(ffi_to_include._cdefsources)
|
|
||||||
self._cdefsources.append(']')
|
|
||||||
self._included_ffis.append(ffi_to_include)
|
|
||||||
|
|
||||||
def new_handle(self, x):
|
|
||||||
return self._backend.newp_handle(self.BVoidP, x)
|
|
||||||
|
|
||||||
def from_handle(self, x):
|
|
||||||
return self._backend.from_handle(x)
|
|
||||||
|
|
||||||
def release(self, x):
|
|
||||||
self._backend.release(x)
|
|
||||||
|
|
||||||
def set_unicode(self, enabled_flag):
|
|
||||||
"""Windows: if 'enabled_flag' is True, enable the UNICODE and
|
|
||||||
_UNICODE defines in C, and declare the types like TCHAR and LPTCSTR
|
|
||||||
to be (pointers to) wchar_t. If 'enabled_flag' is False,
|
|
||||||
declare these types to be (pointers to) plain 8-bit characters.
|
|
||||||
This is mostly for backward compatibility; you usually want True.
|
|
||||||
"""
|
|
||||||
if self._windows_unicode is not None:
|
|
||||||
raise ValueError("set_unicode() can only be called once")
|
|
||||||
enabled_flag = bool(enabled_flag)
|
|
||||||
if enabled_flag:
|
|
||||||
self.cdef("typedef wchar_t TBYTE;"
|
|
||||||
"typedef wchar_t TCHAR;"
|
|
||||||
"typedef const wchar_t *LPCTSTR;"
|
|
||||||
"typedef const wchar_t *PCTSTR;"
|
|
||||||
"typedef wchar_t *LPTSTR;"
|
|
||||||
"typedef wchar_t *PTSTR;"
|
|
||||||
"typedef TBYTE *PTBYTE;"
|
|
||||||
"typedef TCHAR *PTCHAR;")
|
|
||||||
else:
|
|
||||||
self.cdef("typedef char TBYTE;"
|
|
||||||
"typedef char TCHAR;"
|
|
||||||
"typedef const char *LPCTSTR;"
|
|
||||||
"typedef const char *PCTSTR;"
|
|
||||||
"typedef char *LPTSTR;"
|
|
||||||
"typedef char *PTSTR;"
|
|
||||||
"typedef TBYTE *PTBYTE;"
|
|
||||||
"typedef TCHAR *PTCHAR;")
|
|
||||||
self._windows_unicode = enabled_flag
|
|
||||||
|
|
||||||
def _apply_windows_unicode(self, kwds):
|
|
||||||
defmacros = kwds.get('define_macros', ())
|
|
||||||
if not isinstance(defmacros, (list, tuple)):
|
|
||||||
raise TypeError("'define_macros' must be a list or tuple")
|
|
||||||
defmacros = list(defmacros) + [('UNICODE', '1'),
|
|
||||||
('_UNICODE', '1')]
|
|
||||||
kwds['define_macros'] = defmacros
|
|
||||||
|
|
||||||
def _apply_embedding_fix(self, kwds):
|
|
||||||
# must include an argument like "-lpython2.7" for the compiler
|
|
||||||
def ensure(key, value):
|
|
||||||
lst = kwds.setdefault(key, [])
|
|
||||||
if value not in lst:
|
|
||||||
lst.append(value)
|
|
||||||
#
|
|
||||||
if '__pypy__' in sys.builtin_module_names:
|
|
||||||
import os
|
|
||||||
if sys.platform == "win32":
|
|
||||||
# we need 'libpypy-c.lib'. Current distributions of
|
|
||||||
# pypy (>= 4.1) contain it as 'libs/python27.lib'.
|
|
||||||
pythonlib = "python{0[0]}{0[1]}".format(sys.version_info)
|
|
||||||
if hasattr(sys, 'prefix'):
|
|
||||||
ensure('library_dirs', os.path.join(sys.prefix, 'libs'))
|
|
||||||
else:
|
|
||||||
# we need 'libpypy-c.{so,dylib}', which should be by
|
|
||||||
# default located in 'sys.prefix/bin' for installed
|
|
||||||
# systems.
|
|
||||||
if sys.version_info < (3,):
|
|
||||||
pythonlib = "pypy-c"
|
|
||||||
else:
|
|
||||||
pythonlib = "pypy3-c"
|
|
||||||
if hasattr(sys, 'prefix'):
|
|
||||||
ensure('library_dirs', os.path.join(sys.prefix, 'bin'))
|
|
||||||
# On uninstalled pypy's, the libpypy-c is typically found in
|
|
||||||
# .../pypy/goal/.
|
|
||||||
if hasattr(sys, 'prefix'):
|
|
||||||
ensure('library_dirs', os.path.join(sys.prefix, 'pypy', 'goal'))
|
|
||||||
else:
|
|
||||||
if sys.platform == "win32":
|
|
||||||
template = "python%d%d"
|
|
||||||
if hasattr(sys, 'gettotalrefcount'):
|
|
||||||
template += '_d'
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
import sysconfig
|
|
||||||
except ImportError: # 2.6
|
|
||||||
from distutils import sysconfig
|
|
||||||
template = "python%d.%d"
|
|
||||||
if sysconfig.get_config_var('DEBUG_EXT'):
|
|
||||||
template += sysconfig.get_config_var('DEBUG_EXT')
|
|
||||||
pythonlib = (template %
|
|
||||||
(sys.hexversion >> 24, (sys.hexversion >> 16) & 0xff))
|
|
||||||
if hasattr(sys, 'abiflags'):
|
|
||||||
pythonlib += sys.abiflags
|
|
||||||
ensure('libraries', pythonlib)
|
|
||||||
if sys.platform == "win32":
|
|
||||||
ensure('extra_link_args', '/MANIFEST')
|
|
||||||
|
|
||||||
def set_source(self, module_name, source, source_extension='.c', **kwds):
|
|
||||||
import os
|
|
||||||
if hasattr(self, '_assigned_source'):
|
|
||||||
raise ValueError("set_source() cannot be called several times "
|
|
||||||
"per ffi object")
|
|
||||||
if not isinstance(module_name, basestring):
|
|
||||||
raise TypeError("'module_name' must be a string")
|
|
||||||
if os.sep in module_name or (os.altsep and os.altsep in module_name):
|
|
||||||
raise ValueError("'module_name' must not contain '/': use a dotted "
|
|
||||||
"name to make a 'package.module' location")
|
|
||||||
self._assigned_source = (str(module_name), source,
|
|
||||||
source_extension, kwds)
|
|
||||||
|
|
||||||
def set_source_pkgconfig(self, module_name, pkgconfig_libs, source,
|
|
||||||
source_extension='.c', **kwds):
|
|
||||||
from . import pkgconfig
|
|
||||||
if not isinstance(pkgconfig_libs, list):
|
|
||||||
raise TypeError("the pkgconfig_libs argument must be a list "
|
|
||||||
"of package names")
|
|
||||||
kwds2 = pkgconfig.flags_from_pkgconfig(pkgconfig_libs)
|
|
||||||
pkgconfig.merge_flags(kwds, kwds2)
|
|
||||||
self.set_source(module_name, source, source_extension, **kwds)
|
|
||||||
|
|
||||||
def distutils_extension(self, tmpdir='build', verbose=True):
|
|
||||||
from distutils.dir_util import mkpath
|
|
||||||
from .recompiler import recompile
|
|
||||||
#
|
|
||||||
if not hasattr(self, '_assigned_source'):
|
|
||||||
if hasattr(self, 'verifier'): # fallback, 'tmpdir' ignored
|
|
||||||
return self.verifier.get_extension()
|
|
||||||
raise ValueError("set_source() must be called before"
|
|
||||||
" distutils_extension()")
|
|
||||||
module_name, source, source_extension, kwds = self._assigned_source
|
|
||||||
if source is None:
|
|
||||||
raise TypeError("distutils_extension() is only for C extension "
|
|
||||||
"modules, not for dlopen()-style pure Python "
|
|
||||||
"modules")
|
|
||||||
mkpath(tmpdir)
|
|
||||||
ext, updated = recompile(self, module_name,
|
|
||||||
source, tmpdir=tmpdir, extradir=tmpdir,
|
|
||||||
source_extension=source_extension,
|
|
||||||
call_c_compiler=False, **kwds)
|
|
||||||
if verbose:
|
|
||||||
if updated:
|
|
||||||
sys.stderr.write("regenerated: %r\n" % (ext.sources[0],))
|
|
||||||
else:
|
|
||||||
sys.stderr.write("not modified: %r\n" % (ext.sources[0],))
|
|
||||||
return ext
|
|
||||||
|
|
||||||
def emit_c_code(self, filename):
|
|
||||||
from .recompiler import recompile
|
|
||||||
#
|
|
||||||
if not hasattr(self, '_assigned_source'):
|
|
||||||
raise ValueError("set_source() must be called before emit_c_code()")
|
|
||||||
module_name, source, source_extension, kwds = self._assigned_source
|
|
||||||
if source is None:
|
|
||||||
raise TypeError("emit_c_code() is only for C extension modules, "
|
|
||||||
"not for dlopen()-style pure Python modules")
|
|
||||||
recompile(self, module_name, source,
|
|
||||||
c_file=filename, call_c_compiler=False, **kwds)
|
|
||||||
|
|
||||||
def emit_python_code(self, filename):
|
|
||||||
from .recompiler import recompile
|
|
||||||
#
|
|
||||||
if not hasattr(self, '_assigned_source'):
|
|
||||||
raise ValueError("set_source() must be called before emit_c_code()")
|
|
||||||
module_name, source, source_extension, kwds = self._assigned_source
|
|
||||||
if source is not None:
|
|
||||||
raise TypeError("emit_python_code() is only for dlopen()-style "
|
|
||||||
"pure Python modules, not for C extension modules")
|
|
||||||
recompile(self, module_name, source,
|
|
||||||
c_file=filename, call_c_compiler=False, **kwds)
|
|
||||||
|
|
||||||
def compile(self, tmpdir='.', verbose=0, target=None, debug=None):
|
|
||||||
"""The 'target' argument gives the final file name of the
|
|
||||||
compiled DLL. Use '*' to force distutils' choice, suitable for
|
|
||||||
regular CPython C API modules. Use a file name ending in '.*'
|
|
||||||
to ask for the system's default extension for dynamic libraries
|
|
||||||
(.so/.dll/.dylib).
|
|
||||||
|
|
||||||
The default is '*' when building a non-embedded C API extension,
|
|
||||||
and (module_name + '.*') when building an embedded library.
|
|
||||||
"""
|
|
||||||
from .recompiler import recompile
|
|
||||||
#
|
|
||||||
if not hasattr(self, '_assigned_source'):
|
|
||||||
raise ValueError("set_source() must be called before compile()")
|
|
||||||
module_name, source, source_extension, kwds = self._assigned_source
|
|
||||||
return recompile(self, module_name, source, tmpdir=tmpdir,
|
|
||||||
target=target, source_extension=source_extension,
|
|
||||||
compiler_verbose=verbose, debug=debug, **kwds)
|
|
||||||
|
|
||||||
def init_once(self, func, tag):
|
|
||||||
# Read _init_once_cache[tag], which is either (False, lock) if
|
|
||||||
# we're calling the function now in some thread, or (True, result).
|
|
||||||
# Don't call setdefault() in most cases, to avoid allocating and
|
|
||||||
# immediately freeing a lock; but still use setdefaut() to avoid
|
|
||||||
# races.
|
|
||||||
try:
|
|
||||||
x = self._init_once_cache[tag]
|
|
||||||
except KeyError:
|
|
||||||
x = self._init_once_cache.setdefault(tag, (False, allocate_lock()))
|
|
||||||
# Common case: we got (True, result), so we return the result.
|
|
||||||
if x[0]:
|
|
||||||
return x[1]
|
|
||||||
# Else, it's a lock. Acquire it to serialize the following tests.
|
|
||||||
with x[1]:
|
|
||||||
# Read again from _init_once_cache the current status.
|
|
||||||
x = self._init_once_cache[tag]
|
|
||||||
if x[0]:
|
|
||||||
return x[1]
|
|
||||||
# Call the function and store the result back.
|
|
||||||
result = func()
|
|
||||||
self._init_once_cache[tag] = (True, result)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def embedding_init_code(self, pysource):
|
|
||||||
if self._embedding:
|
|
||||||
raise ValueError("embedding_init_code() can only be called once")
|
|
||||||
# fix 'pysource' before it gets dumped into the C file:
|
|
||||||
# - remove empty lines at the beginning, so it starts at "line 1"
|
|
||||||
# - dedent, if all non-empty lines are indented
|
|
||||||
# - check for SyntaxErrors
|
|
||||||
import re
|
|
||||||
match = re.match(r'\s*\n', pysource)
|
|
||||||
if match:
|
|
||||||
pysource = pysource[match.end():]
|
|
||||||
lines = pysource.splitlines() or ['']
|
|
||||||
prefix = re.match(r'\s*', lines[0]).group()
|
|
||||||
for i in range(1, len(lines)):
|
|
||||||
line = lines[i]
|
|
||||||
if line.rstrip():
|
|
||||||
while not line.startswith(prefix):
|
|
||||||
prefix = prefix[:-1]
|
|
||||||
i = len(prefix)
|
|
||||||
lines = [line[i:]+'\n' for line in lines]
|
|
||||||
pysource = ''.join(lines)
|
|
||||||
#
|
|
||||||
compile(pysource, "cffi_init", "exec")
|
|
||||||
#
|
|
||||||
self._embedding = pysource
|
|
||||||
|
|
||||||
def def_extern(self, *args, **kwds):
|
|
||||||
raise ValueError("ffi.def_extern() is only available on API-mode FFI "
|
|
||||||
"objects")
|
|
||||||
|
|
||||||
def list_types(self):
|
|
||||||
"""Returns the user type names known to this FFI instance.
|
|
||||||
This returns a tuple containing three lists of names:
|
|
||||||
(typedef_names, names_of_structs, names_of_unions)
|
|
||||||
"""
|
|
||||||
typedefs = []
|
|
||||||
structs = []
|
|
||||||
unions = []
|
|
||||||
for key in self._parser._declarations:
|
|
||||||
if key.startswith('typedef '):
|
|
||||||
typedefs.append(key[8:])
|
|
||||||
elif key.startswith('struct '):
|
|
||||||
structs.append(key[7:])
|
|
||||||
elif key.startswith('union '):
|
|
||||||
unions.append(key[6:])
|
|
||||||
typedefs.sort()
|
|
||||||
structs.sort()
|
|
||||||
unions.sort()
|
|
||||||
return (typedefs, structs, unions)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_backend_lib(backend, name, flags):
|
|
||||||
import os
|
|
||||||
if not isinstance(name, basestring):
|
|
||||||
if sys.platform != "win32" or name is not None:
|
|
||||||
return backend.load_library(name, flags)
|
|
||||||
name = "c" # Windows: load_library(None) fails, but this works
|
|
||||||
# on Python 2 (backward compatibility hack only)
|
|
||||||
first_error = None
|
|
||||||
if '.' in name or '/' in name or os.sep in name:
|
|
||||||
try:
|
|
||||||
return backend.load_library(name, flags)
|
|
||||||
except OSError as e:
|
|
||||||
first_error = e
|
|
||||||
import ctypes.util
|
|
||||||
path = ctypes.util.find_library(name)
|
|
||||||
if path is None:
|
|
||||||
if name == "c" and sys.platform == "win32" and sys.version_info >= (3,):
|
|
||||||
raise OSError("dlopen(None) cannot work on Windows for Python 3 "
|
|
||||||
"(see http://bugs.python.org/issue23606)")
|
|
||||||
msg = ("ctypes.util.find_library() did not manage "
|
|
||||||
"to locate a library called %r" % (name,))
|
|
||||||
if first_error is not None:
|
|
||||||
msg = "%s. Additionally, %s" % (first_error, msg)
|
|
||||||
raise OSError(msg)
|
|
||||||
return backend.load_library(path, flags)
|
|
||||||
|
|
||||||
def _make_ffi_library(ffi, libname, flags):
|
|
||||||
backend = ffi._backend
|
|
||||||
backendlib = _load_backend_lib(backend, libname, flags)
|
|
||||||
#
|
|
||||||
def accessor_function(name):
|
|
||||||
key = 'function ' + name
|
|
||||||
tp, _ = ffi._parser._declarations[key]
|
|
||||||
BType = ffi._get_cached_btype(tp)
|
|
||||||
value = backendlib.load_function(BType, name)
|
|
||||||
library.__dict__[name] = value
|
|
||||||
#
|
|
||||||
def accessor_variable(name):
|
|
||||||
key = 'variable ' + name
|
|
||||||
tp, _ = ffi._parser._declarations[key]
|
|
||||||
BType = ffi._get_cached_btype(tp)
|
|
||||||
read_variable = backendlib.read_variable
|
|
||||||
write_variable = backendlib.write_variable
|
|
||||||
setattr(FFILibrary, name, property(
|
|
||||||
lambda self: read_variable(BType, name),
|
|
||||||
lambda self, value: write_variable(BType, name, value)))
|
|
||||||
#
|
|
||||||
def addressof_var(name):
|
|
||||||
try:
|
|
||||||
return addr_variables[name]
|
|
||||||
except KeyError:
|
|
||||||
with ffi._lock:
|
|
||||||
if name not in addr_variables:
|
|
||||||
key = 'variable ' + name
|
|
||||||
tp, _ = ffi._parser._declarations[key]
|
|
||||||
BType = ffi._get_cached_btype(tp)
|
|
||||||
if BType.kind != 'array':
|
|
||||||
BType = model.pointer_cache(ffi, BType)
|
|
||||||
p = backendlib.load_function(BType, name)
|
|
||||||
addr_variables[name] = p
|
|
||||||
return addr_variables[name]
|
|
||||||
#
|
|
||||||
def accessor_constant(name):
|
|
||||||
raise NotImplementedError("non-integer constant '%s' cannot be "
|
|
||||||
"accessed from a dlopen() library" % (name,))
|
|
||||||
#
|
|
||||||
def accessor_int_constant(name):
|
|
||||||
library.__dict__[name] = ffi._parser._int_constants[name]
|
|
||||||
#
|
|
||||||
accessors = {}
|
|
||||||
accessors_version = [False]
|
|
||||||
addr_variables = {}
|
|
||||||
#
|
|
||||||
def update_accessors():
|
|
||||||
if accessors_version[0] is ffi._cdef_version:
|
|
||||||
return
|
|
||||||
#
|
|
||||||
for key, (tp, _) in ffi._parser._declarations.items():
|
|
||||||
if not isinstance(tp, model.EnumType):
|
|
||||||
tag, name = key.split(' ', 1)
|
|
||||||
if tag == 'function':
|
|
||||||
accessors[name] = accessor_function
|
|
||||||
elif tag == 'variable':
|
|
||||||
accessors[name] = accessor_variable
|
|
||||||
elif tag == 'constant':
|
|
||||||
accessors[name] = accessor_constant
|
|
||||||
else:
|
|
||||||
for i, enumname in enumerate(tp.enumerators):
|
|
||||||
def accessor_enum(name, tp=tp, i=i):
|
|
||||||
tp.check_not_partial()
|
|
||||||
library.__dict__[name] = tp.enumvalues[i]
|
|
||||||
accessors[enumname] = accessor_enum
|
|
||||||
for name in ffi._parser._int_constants:
|
|
||||||
accessors.setdefault(name, accessor_int_constant)
|
|
||||||
accessors_version[0] = ffi._cdef_version
|
|
||||||
#
|
|
||||||
def make_accessor(name):
|
|
||||||
with ffi._lock:
|
|
||||||
if name in library.__dict__ or name in FFILibrary.__dict__:
|
|
||||||
return # added by another thread while waiting for the lock
|
|
||||||
if name not in accessors:
|
|
||||||
update_accessors()
|
|
||||||
if name not in accessors:
|
|
||||||
raise AttributeError(name)
|
|
||||||
accessors[name](name)
|
|
||||||
#
|
|
||||||
class FFILibrary(object):
|
|
||||||
def __getattr__(self, name):
|
|
||||||
make_accessor(name)
|
|
||||||
return getattr(self, name)
|
|
||||||
def __setattr__(self, name, value):
|
|
||||||
try:
|
|
||||||
property = getattr(self.__class__, name)
|
|
||||||
except AttributeError:
|
|
||||||
make_accessor(name)
|
|
||||||
setattr(self, name, value)
|
|
||||||
else:
|
|
||||||
property.__set__(self, value)
|
|
||||||
def __dir__(self):
|
|
||||||
with ffi._lock:
|
|
||||||
update_accessors()
|
|
||||||
return accessors.keys()
|
|
||||||
def __addressof__(self, name):
|
|
||||||
if name in library.__dict__:
|
|
||||||
return library.__dict__[name]
|
|
||||||
if name in FFILibrary.__dict__:
|
|
||||||
return addressof_var(name)
|
|
||||||
make_accessor(name)
|
|
||||||
if name in library.__dict__:
|
|
||||||
return library.__dict__[name]
|
|
||||||
if name in FFILibrary.__dict__:
|
|
||||||
return addressof_var(name)
|
|
||||||
raise AttributeError("cffi library has no function or "
|
|
||||||
"global variable named '%s'" % (name,))
|
|
||||||
def __cffi_close__(self):
|
|
||||||
backendlib.close_lib()
|
|
||||||
self.__dict__.clear()
|
|
||||||
#
|
|
||||||
if isinstance(libname, basestring):
|
|
||||||
try:
|
|
||||||
if not isinstance(libname, str): # unicode, on Python 2
|
|
||||||
libname = libname.encode('utf-8')
|
|
||||||
FFILibrary.__name__ = 'FFILibrary_%s' % libname
|
|
||||||
except UnicodeError:
|
|
||||||
pass
|
|
||||||
library = FFILibrary()
|
|
||||||
return library, library.__dict__
|
|
||||||
|
|
||||||
def _builtin_function_type(func):
|
|
||||||
# a hack to make at least ffi.typeof(builtin_function) work,
|
|
||||||
# if the builtin function was obtained by 'vengine_cpy'.
|
|
||||||
import sys
|
|
||||||
try:
|
|
||||||
module = sys.modules[func.__module__]
|
|
||||||
ffi = module._cffi_original_ffi
|
|
||||||
types_of_builtin_funcs = module._cffi_types_of_builtin_funcs
|
|
||||||
tp = types_of_builtin_funcs[func]
|
|
||||||
except (KeyError, AttributeError, TypeError):
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
with ffi._lock:
|
|
||||||
return ffi._get_cached_btype(tp)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,187 +0,0 @@
|
|||||||
from .error import VerificationError
|
|
||||||
|
|
||||||
class CffiOp(object):
|
|
||||||
def __init__(self, op, arg):
|
|
||||||
self.op = op
|
|
||||||
self.arg = arg
|
|
||||||
|
|
||||||
def as_c_expr(self):
|
|
||||||
if self.op is None:
|
|
||||||
assert isinstance(self.arg, str)
|
|
||||||
return '(_cffi_opcode_t)(%s)' % (self.arg,)
|
|
||||||
classname = CLASS_NAME[self.op]
|
|
||||||
return '_CFFI_OP(_CFFI_OP_%s, %s)' % (classname, self.arg)
|
|
||||||
|
|
||||||
def as_python_bytes(self):
|
|
||||||
if self.op is None and self.arg.isdigit():
|
|
||||||
value = int(self.arg) # non-negative: '-' not in self.arg
|
|
||||||
if value >= 2**31:
|
|
||||||
raise OverflowError("cannot emit %r: limited to 2**31-1"
|
|
||||||
% (self.arg,))
|
|
||||||
return format_four_bytes(value)
|
|
||||||
if isinstance(self.arg, str):
|
|
||||||
raise VerificationError("cannot emit to Python: %r" % (self.arg,))
|
|
||||||
return format_four_bytes((self.arg << 8) | self.op)
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
classname = CLASS_NAME.get(self.op, self.op)
|
|
||||||
return '(%s %s)' % (classname, self.arg)
|
|
||||||
|
|
||||||
def format_four_bytes(num):
|
|
||||||
return '\\x%02X\\x%02X\\x%02X\\x%02X' % (
|
|
||||||
(num >> 24) & 0xFF,
|
|
||||||
(num >> 16) & 0xFF,
|
|
||||||
(num >> 8) & 0xFF,
|
|
||||||
(num ) & 0xFF)
|
|
||||||
|
|
||||||
OP_PRIMITIVE = 1
|
|
||||||
OP_POINTER = 3
|
|
||||||
OP_ARRAY = 5
|
|
||||||
OP_OPEN_ARRAY = 7
|
|
||||||
OP_STRUCT_UNION = 9
|
|
||||||
OP_ENUM = 11
|
|
||||||
OP_FUNCTION = 13
|
|
||||||
OP_FUNCTION_END = 15
|
|
||||||
OP_NOOP = 17
|
|
||||||
OP_BITFIELD = 19
|
|
||||||
OP_TYPENAME = 21
|
|
||||||
OP_CPYTHON_BLTN_V = 23 # varargs
|
|
||||||
OP_CPYTHON_BLTN_N = 25 # noargs
|
|
||||||
OP_CPYTHON_BLTN_O = 27 # O (i.e. a single arg)
|
|
||||||
OP_CONSTANT = 29
|
|
||||||
OP_CONSTANT_INT = 31
|
|
||||||
OP_GLOBAL_VAR = 33
|
|
||||||
OP_DLOPEN_FUNC = 35
|
|
||||||
OP_DLOPEN_CONST = 37
|
|
||||||
OP_GLOBAL_VAR_F = 39
|
|
||||||
OP_EXTERN_PYTHON = 41
|
|
||||||
|
|
||||||
PRIM_VOID = 0
|
|
||||||
PRIM_BOOL = 1
|
|
||||||
PRIM_CHAR = 2
|
|
||||||
PRIM_SCHAR = 3
|
|
||||||
PRIM_UCHAR = 4
|
|
||||||
PRIM_SHORT = 5
|
|
||||||
PRIM_USHORT = 6
|
|
||||||
PRIM_INT = 7
|
|
||||||
PRIM_UINT = 8
|
|
||||||
PRIM_LONG = 9
|
|
||||||
PRIM_ULONG = 10
|
|
||||||
PRIM_LONGLONG = 11
|
|
||||||
PRIM_ULONGLONG = 12
|
|
||||||
PRIM_FLOAT = 13
|
|
||||||
PRIM_DOUBLE = 14
|
|
||||||
PRIM_LONGDOUBLE = 15
|
|
||||||
|
|
||||||
PRIM_WCHAR = 16
|
|
||||||
PRIM_INT8 = 17
|
|
||||||
PRIM_UINT8 = 18
|
|
||||||
PRIM_INT16 = 19
|
|
||||||
PRIM_UINT16 = 20
|
|
||||||
PRIM_INT32 = 21
|
|
||||||
PRIM_UINT32 = 22
|
|
||||||
PRIM_INT64 = 23
|
|
||||||
PRIM_UINT64 = 24
|
|
||||||
PRIM_INTPTR = 25
|
|
||||||
PRIM_UINTPTR = 26
|
|
||||||
PRIM_PTRDIFF = 27
|
|
||||||
PRIM_SIZE = 28
|
|
||||||
PRIM_SSIZE = 29
|
|
||||||
PRIM_INT_LEAST8 = 30
|
|
||||||
PRIM_UINT_LEAST8 = 31
|
|
||||||
PRIM_INT_LEAST16 = 32
|
|
||||||
PRIM_UINT_LEAST16 = 33
|
|
||||||
PRIM_INT_LEAST32 = 34
|
|
||||||
PRIM_UINT_LEAST32 = 35
|
|
||||||
PRIM_INT_LEAST64 = 36
|
|
||||||
PRIM_UINT_LEAST64 = 37
|
|
||||||
PRIM_INT_FAST8 = 38
|
|
||||||
PRIM_UINT_FAST8 = 39
|
|
||||||
PRIM_INT_FAST16 = 40
|
|
||||||
PRIM_UINT_FAST16 = 41
|
|
||||||
PRIM_INT_FAST32 = 42
|
|
||||||
PRIM_UINT_FAST32 = 43
|
|
||||||
PRIM_INT_FAST64 = 44
|
|
||||||
PRIM_UINT_FAST64 = 45
|
|
||||||
PRIM_INTMAX = 46
|
|
||||||
PRIM_UINTMAX = 47
|
|
||||||
PRIM_FLOATCOMPLEX = 48
|
|
||||||
PRIM_DOUBLECOMPLEX = 49
|
|
||||||
PRIM_CHAR16 = 50
|
|
||||||
PRIM_CHAR32 = 51
|
|
||||||
|
|
||||||
_NUM_PRIM = 52
|
|
||||||
_UNKNOWN_PRIM = -1
|
|
||||||
_UNKNOWN_FLOAT_PRIM = -2
|
|
||||||
_UNKNOWN_LONG_DOUBLE = -3
|
|
||||||
|
|
||||||
_IO_FILE_STRUCT = -1
|
|
||||||
|
|
||||||
PRIMITIVE_TO_INDEX = {
|
|
||||||
'char': PRIM_CHAR,
|
|
||||||
'short': PRIM_SHORT,
|
|
||||||
'int': PRIM_INT,
|
|
||||||
'long': PRIM_LONG,
|
|
||||||
'long long': PRIM_LONGLONG,
|
|
||||||
'signed char': PRIM_SCHAR,
|
|
||||||
'unsigned char': PRIM_UCHAR,
|
|
||||||
'unsigned short': PRIM_USHORT,
|
|
||||||
'unsigned int': PRIM_UINT,
|
|
||||||
'unsigned long': PRIM_ULONG,
|
|
||||||
'unsigned long long': PRIM_ULONGLONG,
|
|
||||||
'float': PRIM_FLOAT,
|
|
||||||
'double': PRIM_DOUBLE,
|
|
||||||
'long double': PRIM_LONGDOUBLE,
|
|
||||||
'float _Complex': PRIM_FLOATCOMPLEX,
|
|
||||||
'double _Complex': PRIM_DOUBLECOMPLEX,
|
|
||||||
'_Bool': PRIM_BOOL,
|
|
||||||
'wchar_t': PRIM_WCHAR,
|
|
||||||
'char16_t': PRIM_CHAR16,
|
|
||||||
'char32_t': PRIM_CHAR32,
|
|
||||||
'int8_t': PRIM_INT8,
|
|
||||||
'uint8_t': PRIM_UINT8,
|
|
||||||
'int16_t': PRIM_INT16,
|
|
||||||
'uint16_t': PRIM_UINT16,
|
|
||||||
'int32_t': PRIM_INT32,
|
|
||||||
'uint32_t': PRIM_UINT32,
|
|
||||||
'int64_t': PRIM_INT64,
|
|
||||||
'uint64_t': PRIM_UINT64,
|
|
||||||
'intptr_t': PRIM_INTPTR,
|
|
||||||
'uintptr_t': PRIM_UINTPTR,
|
|
||||||
'ptrdiff_t': PRIM_PTRDIFF,
|
|
||||||
'size_t': PRIM_SIZE,
|
|
||||||
'ssize_t': PRIM_SSIZE,
|
|
||||||
'int_least8_t': PRIM_INT_LEAST8,
|
|
||||||
'uint_least8_t': PRIM_UINT_LEAST8,
|
|
||||||
'int_least16_t': PRIM_INT_LEAST16,
|
|
||||||
'uint_least16_t': PRIM_UINT_LEAST16,
|
|
||||||
'int_least32_t': PRIM_INT_LEAST32,
|
|
||||||
'uint_least32_t': PRIM_UINT_LEAST32,
|
|
||||||
'int_least64_t': PRIM_INT_LEAST64,
|
|
||||||
'uint_least64_t': PRIM_UINT_LEAST64,
|
|
||||||
'int_fast8_t': PRIM_INT_FAST8,
|
|
||||||
'uint_fast8_t': PRIM_UINT_FAST8,
|
|
||||||
'int_fast16_t': PRIM_INT_FAST16,
|
|
||||||
'uint_fast16_t': PRIM_UINT_FAST16,
|
|
||||||
'int_fast32_t': PRIM_INT_FAST32,
|
|
||||||
'uint_fast32_t': PRIM_UINT_FAST32,
|
|
||||||
'int_fast64_t': PRIM_INT_FAST64,
|
|
||||||
'uint_fast64_t': PRIM_UINT_FAST64,
|
|
||||||
'intmax_t': PRIM_INTMAX,
|
|
||||||
'uintmax_t': PRIM_UINTMAX,
|
|
||||||
}
|
|
||||||
|
|
||||||
F_UNION = 0x01
|
|
||||||
F_CHECK_FIELDS = 0x02
|
|
||||||
F_PACKED = 0x04
|
|
||||||
F_EXTERNAL = 0x08
|
|
||||||
F_OPAQUE = 0x10
|
|
||||||
|
|
||||||
G_FLAGS = dict([('_CFFI_' + _key, globals()[_key])
|
|
||||||
for _key in ['F_UNION', 'F_CHECK_FIELDS', 'F_PACKED',
|
|
||||||
'F_EXTERNAL', 'F_OPAQUE']])
|
|
||||||
|
|
||||||
CLASS_NAME = {}
|
|
||||||
for _name, _value in list(globals().items()):
|
|
||||||
if _name.startswith('OP_') and isinstance(_value, int):
|
|
||||||
CLASS_NAME[_value] = _name[3:]
|
|
||||||
@@ -1,80 +0,0 @@
|
|||||||
import sys
|
|
||||||
from . import model
|
|
||||||
from .error import FFIError
|
|
||||||
|
|
||||||
|
|
||||||
COMMON_TYPES = {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
# fetch "bool" and all simple Windows types
|
|
||||||
from _cffi_backend import _get_common_types
|
|
||||||
_get_common_types(COMMON_TYPES)
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
COMMON_TYPES['FILE'] = model.unknown_type('FILE', '_IO_FILE')
|
|
||||||
COMMON_TYPES['bool'] = '_Bool' # in case we got ImportError above
|
|
||||||
|
|
||||||
for _type in model.PrimitiveType.ALL_PRIMITIVE_TYPES:
|
|
||||||
if _type.endswith('_t'):
|
|
||||||
COMMON_TYPES[_type] = _type
|
|
||||||
del _type
|
|
||||||
|
|
||||||
_CACHE = {}
|
|
||||||
|
|
||||||
def resolve_common_type(parser, commontype):
|
|
||||||
try:
|
|
||||||
return _CACHE[commontype]
|
|
||||||
except KeyError:
|
|
||||||
cdecl = COMMON_TYPES.get(commontype, commontype)
|
|
||||||
if not isinstance(cdecl, str):
|
|
||||||
result, quals = cdecl, 0 # cdecl is already a BaseType
|
|
||||||
elif cdecl in model.PrimitiveType.ALL_PRIMITIVE_TYPES:
|
|
||||||
result, quals = model.PrimitiveType(cdecl), 0
|
|
||||||
elif cdecl == 'set-unicode-needed':
|
|
||||||
raise FFIError("The Windows type %r is only available after "
|
|
||||||
"you call ffi.set_unicode()" % (commontype,))
|
|
||||||
else:
|
|
||||||
if commontype == cdecl:
|
|
||||||
raise FFIError(
|
|
||||||
"Unsupported type: %r. Please look at "
|
|
||||||
"http://cffi.readthedocs.io/en/latest/cdef.html#ffi-cdef-limitations "
|
|
||||||
"and file an issue if you think this type should really "
|
|
||||||
"be supported." % (commontype,))
|
|
||||||
result, quals = parser.parse_type_and_quals(cdecl) # recursive
|
|
||||||
|
|
||||||
assert isinstance(result, model.BaseTypeByIdentity)
|
|
||||||
_CACHE[commontype] = result, quals
|
|
||||||
return result, quals
|
|
||||||
|
|
||||||
|
|
||||||
# ____________________________________________________________
|
|
||||||
# extra types for Windows (most of them are in commontypes.c)
|
|
||||||
|
|
||||||
|
|
||||||
def win_common_types():
|
|
||||||
return {
|
|
||||||
"UNICODE_STRING": model.StructType(
|
|
||||||
"_UNICODE_STRING",
|
|
||||||
["Length",
|
|
||||||
"MaximumLength",
|
|
||||||
"Buffer"],
|
|
||||||
[model.PrimitiveType("unsigned short"),
|
|
||||||
model.PrimitiveType("unsigned short"),
|
|
||||||
model.PointerType(model.PrimitiveType("wchar_t"))],
|
|
||||||
[-1, -1, -1]),
|
|
||||||
"PUNICODE_STRING": "UNICODE_STRING *",
|
|
||||||
"PCUNICODE_STRING": "const UNICODE_STRING *",
|
|
||||||
|
|
||||||
"TBYTE": "set-unicode-needed",
|
|
||||||
"TCHAR": "set-unicode-needed",
|
|
||||||
"LPCTSTR": "set-unicode-needed",
|
|
||||||
"PCTSTR": "set-unicode-needed",
|
|
||||||
"LPTSTR": "set-unicode-needed",
|
|
||||||
"PTSTR": "set-unicode-needed",
|
|
||||||
"PTBYTE": "set-unicode-needed",
|
|
||||||
"PTCHAR": "set-unicode-needed",
|
|
||||||
}
|
|
||||||
|
|
||||||
if sys.platform == 'win32':
|
|
||||||
COMMON_TYPES.update(win_common_types())
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,31 +0,0 @@
|
|||||||
|
|
||||||
class FFIError(Exception):
|
|
||||||
__module__ = 'cffi'
|
|
||||||
|
|
||||||
class CDefError(Exception):
|
|
||||||
__module__ = 'cffi'
|
|
||||||
def __str__(self):
|
|
||||||
try:
|
|
||||||
current_decl = self.args[1]
|
|
||||||
filename = current_decl.coord.file
|
|
||||||
linenum = current_decl.coord.line
|
|
||||||
prefix = '%s:%d: ' % (filename, linenum)
|
|
||||||
except (AttributeError, TypeError, IndexError):
|
|
||||||
prefix = ''
|
|
||||||
return '%s%s' % (prefix, self.args[0])
|
|
||||||
|
|
||||||
class VerificationError(Exception):
|
|
||||||
""" An error raised when verification fails
|
|
||||||
"""
|
|
||||||
__module__ = 'cffi'
|
|
||||||
|
|
||||||
class VerificationMissing(Exception):
|
|
||||||
""" An error raised when incomplete structures are passed into
|
|
||||||
cdef, but no verification has been done
|
|
||||||
"""
|
|
||||||
__module__ = 'cffi'
|
|
||||||
|
|
||||||
class PkgConfigError(Exception):
|
|
||||||
""" An error raised for missing modules in pkg-config
|
|
||||||
"""
|
|
||||||
__module__ = 'cffi'
|
|
||||||
@@ -1,127 +0,0 @@
|
|||||||
import sys, os
|
|
||||||
from .error import VerificationError
|
|
||||||
|
|
||||||
|
|
||||||
LIST_OF_FILE_NAMES = ['sources', 'include_dirs', 'library_dirs',
|
|
||||||
'extra_objects', 'depends']
|
|
||||||
|
|
||||||
def get_extension(srcfilename, modname, sources=(), **kwds):
|
|
||||||
_hack_at_distutils()
|
|
||||||
from distutils.core import Extension
|
|
||||||
allsources = [srcfilename]
|
|
||||||
for src in sources:
|
|
||||||
allsources.append(os.path.normpath(src))
|
|
||||||
return Extension(name=modname, sources=allsources, **kwds)
|
|
||||||
|
|
||||||
def compile(tmpdir, ext, compiler_verbose=0, debug=None):
|
|
||||||
"""Compile a C extension module using distutils."""
|
|
||||||
|
|
||||||
_hack_at_distutils()
|
|
||||||
saved_environ = os.environ.copy()
|
|
||||||
try:
|
|
||||||
outputfilename = _build(tmpdir, ext, compiler_verbose, debug)
|
|
||||||
outputfilename = os.path.abspath(outputfilename)
|
|
||||||
finally:
|
|
||||||
# workaround for a distutils bugs where some env vars can
|
|
||||||
# become longer and longer every time it is used
|
|
||||||
for key, value in saved_environ.items():
|
|
||||||
if os.environ.get(key) != value:
|
|
||||||
os.environ[key] = value
|
|
||||||
return outputfilename
|
|
||||||
|
|
||||||
def _build(tmpdir, ext, compiler_verbose=0, debug=None):
|
|
||||||
# XXX compact but horrible :-(
|
|
||||||
from distutils.core import Distribution
|
|
||||||
import distutils.errors, distutils.log
|
|
||||||
#
|
|
||||||
dist = Distribution({'ext_modules': [ext]})
|
|
||||||
dist.parse_config_files()
|
|
||||||
options = dist.get_option_dict('build_ext')
|
|
||||||
if debug is None:
|
|
||||||
debug = sys.flags.debug
|
|
||||||
options['debug'] = ('ffiplatform', debug)
|
|
||||||
options['force'] = ('ffiplatform', True)
|
|
||||||
options['build_lib'] = ('ffiplatform', tmpdir)
|
|
||||||
options['build_temp'] = ('ffiplatform', tmpdir)
|
|
||||||
#
|
|
||||||
try:
|
|
||||||
old_level = distutils.log.set_threshold(0) or 0
|
|
||||||
try:
|
|
||||||
distutils.log.set_verbosity(compiler_verbose)
|
|
||||||
dist.run_command('build_ext')
|
|
||||||
cmd_obj = dist.get_command_obj('build_ext')
|
|
||||||
[soname] = cmd_obj.get_outputs()
|
|
||||||
finally:
|
|
||||||
distutils.log.set_threshold(old_level)
|
|
||||||
except (distutils.errors.CompileError,
|
|
||||||
distutils.errors.LinkError) as e:
|
|
||||||
raise VerificationError('%s: %s' % (e.__class__.__name__, e))
|
|
||||||
#
|
|
||||||
return soname
|
|
||||||
|
|
||||||
try:
|
|
||||||
from os.path import samefile
|
|
||||||
except ImportError:
|
|
||||||
def samefile(f1, f2):
|
|
||||||
return os.path.abspath(f1) == os.path.abspath(f2)
|
|
||||||
|
|
||||||
def maybe_relative_path(path):
|
|
||||||
if not os.path.isabs(path):
|
|
||||||
return path # already relative
|
|
||||||
dir = path
|
|
||||||
names = []
|
|
||||||
while True:
|
|
||||||
prevdir = dir
|
|
||||||
dir, name = os.path.split(prevdir)
|
|
||||||
if dir == prevdir or not dir:
|
|
||||||
return path # failed to make it relative
|
|
||||||
names.append(name)
|
|
||||||
try:
|
|
||||||
if samefile(dir, os.curdir):
|
|
||||||
names.reverse()
|
|
||||||
return os.path.join(*names)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# ____________________________________________________________
|
|
||||||
|
|
||||||
try:
|
|
||||||
int_or_long = (int, long)
|
|
||||||
import cStringIO
|
|
||||||
except NameError:
|
|
||||||
int_or_long = int # Python 3
|
|
||||||
import io as cStringIO
|
|
||||||
|
|
||||||
def _flatten(x, f):
|
|
||||||
if isinstance(x, str):
|
|
||||||
f.write('%ds%s' % (len(x), x))
|
|
||||||
elif isinstance(x, dict):
|
|
||||||
keys = sorted(x.keys())
|
|
||||||
f.write('%dd' % len(keys))
|
|
||||||
for key in keys:
|
|
||||||
_flatten(key, f)
|
|
||||||
_flatten(x[key], f)
|
|
||||||
elif isinstance(x, (list, tuple)):
|
|
||||||
f.write('%dl' % len(x))
|
|
||||||
for value in x:
|
|
||||||
_flatten(value, f)
|
|
||||||
elif isinstance(x, int_or_long):
|
|
||||||
f.write('%di' % (x,))
|
|
||||||
else:
|
|
||||||
raise TypeError(
|
|
||||||
"the keywords to verify() contains unsupported object %r" % (x,))
|
|
||||||
|
|
||||||
def flatten(x):
|
|
||||||
f = cStringIO.StringIO()
|
|
||||||
_flatten(x, f)
|
|
||||||
return f.getvalue()
|
|
||||||
|
|
||||||
def _hack_at_distutils():
|
|
||||||
# Windows-only workaround for some configurations: see
|
|
||||||
# https://bugs.python.org/issue23246 (Python 2.7 with
|
|
||||||
# a specific MS compiler suite download)
|
|
||||||
if sys.platform == "win32":
|
|
||||||
try:
|
|
||||||
import setuptools # for side-effects, patches distutils
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
import sys
|
|
||||||
|
|
||||||
if sys.version_info < (3,):
|
|
||||||
try:
|
|
||||||
from thread import allocate_lock
|
|
||||||
except ImportError:
|
|
||||||
from dummy_thread import allocate_lock
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
from _thread import allocate_lock
|
|
||||||
except ImportError:
|
|
||||||
from _dummy_thread import allocate_lock
|
|
||||||
|
|
||||||
|
|
||||||
##import sys
|
|
||||||
##l1 = allocate_lock
|
|
||||||
|
|
||||||
##class allocate_lock(object):
|
|
||||||
## def __init__(self):
|
|
||||||
## self._real = l1()
|
|
||||||
## def __enter__(self):
|
|
||||||
## for i in range(4, 0, -1):
|
|
||||||
## print sys._getframe(i).f_code
|
|
||||||
## print
|
|
||||||
## return self._real.__enter__()
|
|
||||||
## def __exit__(self, *args):
|
|
||||||
## return self._real.__exit__(*args)
|
|
||||||
## def acquire(self, f):
|
|
||||||
## assert f is False
|
|
||||||
## return self._real.acquire(f)
|
|
||||||
@@ -1,617 +0,0 @@
|
|||||||
import types
|
|
||||||
import weakref
|
|
||||||
|
|
||||||
from .lock import allocate_lock
|
|
||||||
from .error import CDefError, VerificationError, VerificationMissing
|
|
||||||
|
|
||||||
# type qualifiers
|
|
||||||
Q_CONST = 0x01
|
|
||||||
Q_RESTRICT = 0x02
|
|
||||||
Q_VOLATILE = 0x04
|
|
||||||
|
|
||||||
def qualify(quals, replace_with):
|
|
||||||
if quals & Q_CONST:
|
|
||||||
replace_with = ' const ' + replace_with.lstrip()
|
|
||||||
if quals & Q_VOLATILE:
|
|
||||||
replace_with = ' volatile ' + replace_with.lstrip()
|
|
||||||
if quals & Q_RESTRICT:
|
|
||||||
# It seems that __restrict is supported by gcc and msvc.
|
|
||||||
# If you hit some different compiler, add a #define in
|
|
||||||
# _cffi_include.h for it (and in its copies, documented there)
|
|
||||||
replace_with = ' __restrict ' + replace_with.lstrip()
|
|
||||||
return replace_with
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTypeByIdentity(object):
|
|
||||||
is_array_type = False
|
|
||||||
is_raw_function = False
|
|
||||||
|
|
||||||
def get_c_name(self, replace_with='', context='a C file', quals=0):
|
|
||||||
result = self.c_name_with_marker
|
|
||||||
assert result.count('&') == 1
|
|
||||||
# some logic duplication with ffi.getctype()... :-(
|
|
||||||
replace_with = replace_with.strip()
|
|
||||||
if replace_with:
|
|
||||||
if replace_with.startswith('*') and '&[' in result:
|
|
||||||
replace_with = '(%s)' % replace_with
|
|
||||||
elif not replace_with[0] in '[(':
|
|
||||||
replace_with = ' ' + replace_with
|
|
||||||
replace_with = qualify(quals, replace_with)
|
|
||||||
result = result.replace('&', replace_with)
|
|
||||||
if '$' in result:
|
|
||||||
raise VerificationError(
|
|
||||||
"cannot generate '%s' in %s: unknown type name"
|
|
||||||
% (self._get_c_name(), context))
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _get_c_name(self):
|
|
||||||
return self.c_name_with_marker.replace('&', '')
|
|
||||||
|
|
||||||
def has_c_name(self):
|
|
||||||
return '$' not in self._get_c_name()
|
|
||||||
|
|
||||||
def is_integer_type(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_cached_btype(self, ffi, finishlist, can_delay=False):
|
|
||||||
try:
|
|
||||||
BType = ffi._cached_btypes[self]
|
|
||||||
except KeyError:
|
|
||||||
BType = self.build_backend_type(ffi, finishlist)
|
|
||||||
BType2 = ffi._cached_btypes.setdefault(self, BType)
|
|
||||||
assert BType2 is BType
|
|
||||||
return BType
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return '<%s>' % (self._get_c_name(),)
|
|
||||||
|
|
||||||
def _get_items(self):
|
|
||||||
return [(name, getattr(self, name)) for name in self._attrs_]
|
|
||||||
|
|
||||||
|
|
||||||
class BaseType(BaseTypeByIdentity):
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
return (self.__class__ == other.__class__ and
|
|
||||||
self._get_items() == other._get_items())
|
|
||||||
|
|
||||||
def __ne__(self, other):
|
|
||||||
return not self == other
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return hash((self.__class__, tuple(self._get_items())))
|
|
||||||
|
|
||||||
|
|
||||||
class VoidType(BaseType):
|
|
||||||
_attrs_ = ()
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.c_name_with_marker = 'void&'
|
|
||||||
|
|
||||||
def build_backend_type(self, ffi, finishlist):
|
|
||||||
return global_cache(self, ffi, 'new_void_type')
|
|
||||||
|
|
||||||
void_type = VoidType()
|
|
||||||
|
|
||||||
|
|
||||||
class BasePrimitiveType(BaseType):
|
|
||||||
def is_complex_type(self):
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class PrimitiveType(BasePrimitiveType):
|
|
||||||
_attrs_ = ('name',)
|
|
||||||
|
|
||||||
ALL_PRIMITIVE_TYPES = {
|
|
||||||
'char': 'c',
|
|
||||||
'short': 'i',
|
|
||||||
'int': 'i',
|
|
||||||
'long': 'i',
|
|
||||||
'long long': 'i',
|
|
||||||
'signed char': 'i',
|
|
||||||
'unsigned char': 'i',
|
|
||||||
'unsigned short': 'i',
|
|
||||||
'unsigned int': 'i',
|
|
||||||
'unsigned long': 'i',
|
|
||||||
'unsigned long long': 'i',
|
|
||||||
'float': 'f',
|
|
||||||
'double': 'f',
|
|
||||||
'long double': 'f',
|
|
||||||
'float _Complex': 'j',
|
|
||||||
'double _Complex': 'j',
|
|
||||||
'_Bool': 'i',
|
|
||||||
# the following types are not primitive in the C sense
|
|
||||||
'wchar_t': 'c',
|
|
||||||
'char16_t': 'c',
|
|
||||||
'char32_t': 'c',
|
|
||||||
'int8_t': 'i',
|
|
||||||
'uint8_t': 'i',
|
|
||||||
'int16_t': 'i',
|
|
||||||
'uint16_t': 'i',
|
|
||||||
'int32_t': 'i',
|
|
||||||
'uint32_t': 'i',
|
|
||||||
'int64_t': 'i',
|
|
||||||
'uint64_t': 'i',
|
|
||||||
'int_least8_t': 'i',
|
|
||||||
'uint_least8_t': 'i',
|
|
||||||
'int_least16_t': 'i',
|
|
||||||
'uint_least16_t': 'i',
|
|
||||||
'int_least32_t': 'i',
|
|
||||||
'uint_least32_t': 'i',
|
|
||||||
'int_least64_t': 'i',
|
|
||||||
'uint_least64_t': 'i',
|
|
||||||
'int_fast8_t': 'i',
|
|
||||||
'uint_fast8_t': 'i',
|
|
||||||
'int_fast16_t': 'i',
|
|
||||||
'uint_fast16_t': 'i',
|
|
||||||
'int_fast32_t': 'i',
|
|
||||||
'uint_fast32_t': 'i',
|
|
||||||
'int_fast64_t': 'i',
|
|
||||||
'uint_fast64_t': 'i',
|
|
||||||
'intptr_t': 'i',
|
|
||||||
'uintptr_t': 'i',
|
|
||||||
'intmax_t': 'i',
|
|
||||||
'uintmax_t': 'i',
|
|
||||||
'ptrdiff_t': 'i',
|
|
||||||
'size_t': 'i',
|
|
||||||
'ssize_t': 'i',
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, name):
|
|
||||||
assert name in self.ALL_PRIMITIVE_TYPES
|
|
||||||
self.name = name
|
|
||||||
self.c_name_with_marker = name + '&'
|
|
||||||
|
|
||||||
def is_char_type(self):
|
|
||||||
return self.ALL_PRIMITIVE_TYPES[self.name] == 'c'
|
|
||||||
def is_integer_type(self):
|
|
||||||
return self.ALL_PRIMITIVE_TYPES[self.name] == 'i'
|
|
||||||
def is_float_type(self):
|
|
||||||
return self.ALL_PRIMITIVE_TYPES[self.name] == 'f'
|
|
||||||
def is_complex_type(self):
|
|
||||||
return self.ALL_PRIMITIVE_TYPES[self.name] == 'j'
|
|
||||||
|
|
||||||
def build_backend_type(self, ffi, finishlist):
|
|
||||||
return global_cache(self, ffi, 'new_primitive_type', self.name)
|
|
||||||
|
|
||||||
|
|
||||||
class UnknownIntegerType(BasePrimitiveType):
|
|
||||||
_attrs_ = ('name',)
|
|
||||||
|
|
||||||
def __init__(self, name):
|
|
||||||
self.name = name
|
|
||||||
self.c_name_with_marker = name + '&'
|
|
||||||
|
|
||||||
def is_integer_type(self):
|
|
||||||
return True
|
|
||||||
|
|
||||||
def build_backend_type(self, ffi, finishlist):
|
|
||||||
raise NotImplementedError("integer type '%s' can only be used after "
|
|
||||||
"compilation" % self.name)
|
|
||||||
|
|
||||||
class UnknownFloatType(BasePrimitiveType):
|
|
||||||
_attrs_ = ('name', )
|
|
||||||
|
|
||||||
def __init__(self, name):
|
|
||||||
self.name = name
|
|
||||||
self.c_name_with_marker = name + '&'
|
|
||||||
|
|
||||||
def build_backend_type(self, ffi, finishlist):
|
|
||||||
raise NotImplementedError("float type '%s' can only be used after "
|
|
||||||
"compilation" % self.name)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseFunctionType(BaseType):
|
|
||||||
_attrs_ = ('args', 'result', 'ellipsis', 'abi')
|
|
||||||
|
|
||||||
def __init__(self, args, result, ellipsis, abi=None):
|
|
||||||
self.args = args
|
|
||||||
self.result = result
|
|
||||||
self.ellipsis = ellipsis
|
|
||||||
self.abi = abi
|
|
||||||
#
|
|
||||||
reprargs = [arg._get_c_name() for arg in self.args]
|
|
||||||
if self.ellipsis:
|
|
||||||
reprargs.append('...')
|
|
||||||
reprargs = reprargs or ['void']
|
|
||||||
replace_with = self._base_pattern % (', '.join(reprargs),)
|
|
||||||
if abi is not None:
|
|
||||||
replace_with = replace_with[:1] + abi + ' ' + replace_with[1:]
|
|
||||||
self.c_name_with_marker = (
|
|
||||||
self.result.c_name_with_marker.replace('&', replace_with))
|
|
||||||
|
|
||||||
|
|
||||||
class RawFunctionType(BaseFunctionType):
|
|
||||||
# Corresponds to a C type like 'int(int)', which is the C type of
|
|
||||||
# a function, but not a pointer-to-function. The backend has no
|
|
||||||
# notion of such a type; it's used temporarily by parsing.
|
|
||||||
_base_pattern = '(&)(%s)'
|
|
||||||
is_raw_function = True
|
|
||||||
|
|
||||||
def build_backend_type(self, ffi, finishlist):
|
|
||||||
raise CDefError("cannot render the type %r: it is a function "
|
|
||||||
"type, not a pointer-to-function type" % (self,))
|
|
||||||
|
|
||||||
def as_function_pointer(self):
|
|
||||||
return FunctionPtrType(self.args, self.result, self.ellipsis, self.abi)
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionPtrType(BaseFunctionType):
|
|
||||||
_base_pattern = '(*&)(%s)'
|
|
||||||
|
|
||||||
def build_backend_type(self, ffi, finishlist):
|
|
||||||
result = self.result.get_cached_btype(ffi, finishlist)
|
|
||||||
args = []
|
|
||||||
for tp in self.args:
|
|
||||||
args.append(tp.get_cached_btype(ffi, finishlist))
|
|
||||||
abi_args = ()
|
|
||||||
if self.abi == "__stdcall":
|
|
||||||
if not self.ellipsis: # __stdcall ignored for variadic funcs
|
|
||||||
try:
|
|
||||||
abi_args = (ffi._backend.FFI_STDCALL,)
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
return global_cache(self, ffi, 'new_function_type',
|
|
||||||
tuple(args), result, self.ellipsis, *abi_args)
|
|
||||||
|
|
||||||
def as_raw_function(self):
|
|
||||||
return RawFunctionType(self.args, self.result, self.ellipsis, self.abi)
|
|
||||||
|
|
||||||
|
|
||||||
class PointerType(BaseType):
|
|
||||||
_attrs_ = ('totype', 'quals')
|
|
||||||
|
|
||||||
def __init__(self, totype, quals=0):
|
|
||||||
self.totype = totype
|
|
||||||
self.quals = quals
|
|
||||||
extra = qualify(quals, " *&")
|
|
||||||
if totype.is_array_type:
|
|
||||||
extra = "(%s)" % (extra.lstrip(),)
|
|
||||||
self.c_name_with_marker = totype.c_name_with_marker.replace('&', extra)
|
|
||||||
|
|
||||||
def build_backend_type(self, ffi, finishlist):
|
|
||||||
BItem = self.totype.get_cached_btype(ffi, finishlist, can_delay=True)
|
|
||||||
return global_cache(self, ffi, 'new_pointer_type', BItem)
|
|
||||||
|
|
||||||
voidp_type = PointerType(void_type)
|
|
||||||
|
|
||||||
def ConstPointerType(totype):
|
|
||||||
return PointerType(totype, Q_CONST)
|
|
||||||
|
|
||||||
const_voidp_type = ConstPointerType(void_type)
|
|
||||||
|
|
||||||
|
|
||||||
class NamedPointerType(PointerType):
|
|
||||||
_attrs_ = ('totype', 'name')
|
|
||||||
|
|
||||||
def __init__(self, totype, name, quals=0):
|
|
||||||
PointerType.__init__(self, totype, quals)
|
|
||||||
self.name = name
|
|
||||||
self.c_name_with_marker = name + '&'
|
|
||||||
|
|
||||||
|
|
||||||
class ArrayType(BaseType):
|
|
||||||
_attrs_ = ('item', 'length')
|
|
||||||
is_array_type = True
|
|
||||||
|
|
||||||
def __init__(self, item, length):
|
|
||||||
self.item = item
|
|
||||||
self.length = length
|
|
||||||
#
|
|
||||||
if length is None:
|
|
||||||
brackets = '&[]'
|
|
||||||
elif length == '...':
|
|
||||||
brackets = '&[/*...*/]'
|
|
||||||
else:
|
|
||||||
brackets = '&[%s]' % length
|
|
||||||
self.c_name_with_marker = (
|
|
||||||
self.item.c_name_with_marker.replace('&', brackets))
|
|
||||||
|
|
||||||
def length_is_unknown(self):
|
|
||||||
return isinstance(self.length, str)
|
|
||||||
|
|
||||||
def resolve_length(self, newlength):
|
|
||||||
return ArrayType(self.item, newlength)
|
|
||||||
|
|
||||||
def build_backend_type(self, ffi, finishlist):
|
|
||||||
if self.length_is_unknown():
|
|
||||||
raise CDefError("cannot render the type %r: unknown length" %
|
|
||||||
(self,))
|
|
||||||
self.item.get_cached_btype(ffi, finishlist) # force the item BType
|
|
||||||
BPtrItem = PointerType(self.item).get_cached_btype(ffi, finishlist)
|
|
||||||
return global_cache(self, ffi, 'new_array_type', BPtrItem, self.length)
|
|
||||||
|
|
||||||
char_array_type = ArrayType(PrimitiveType('char'), None)
|
|
||||||
|
|
||||||
|
|
||||||
class StructOrUnionOrEnum(BaseTypeByIdentity):
|
|
||||||
_attrs_ = ('name',)
|
|
||||||
forcename = None
|
|
||||||
|
|
||||||
def build_c_name_with_marker(self):
|
|
||||||
name = self.forcename or '%s %s' % (self.kind, self.name)
|
|
||||||
self.c_name_with_marker = name + '&'
|
|
||||||
|
|
||||||
def force_the_name(self, forcename):
|
|
||||||
self.forcename = forcename
|
|
||||||
self.build_c_name_with_marker()
|
|
||||||
|
|
||||||
def get_official_name(self):
|
|
||||||
assert self.c_name_with_marker.endswith('&')
|
|
||||||
return self.c_name_with_marker[:-1]
|
|
||||||
|
|
||||||
|
|
||||||
class StructOrUnion(StructOrUnionOrEnum):
|
|
||||||
fixedlayout = None
|
|
||||||
completed = 0
|
|
||||||
partial = False
|
|
||||||
packed = 0
|
|
||||||
|
|
||||||
def __init__(self, name, fldnames, fldtypes, fldbitsize, fldquals=None):
|
|
||||||
self.name = name
|
|
||||||
self.fldnames = fldnames
|
|
||||||
self.fldtypes = fldtypes
|
|
||||||
self.fldbitsize = fldbitsize
|
|
||||||
self.fldquals = fldquals
|
|
||||||
self.build_c_name_with_marker()
|
|
||||||
|
|
||||||
def anonymous_struct_fields(self):
|
|
||||||
if self.fldtypes is not None:
|
|
||||||
for name, type in zip(self.fldnames, self.fldtypes):
|
|
||||||
if name == '' and isinstance(type, StructOrUnion):
|
|
||||||
yield type
|
|
||||||
|
|
||||||
def enumfields(self, expand_anonymous_struct_union=True):
|
|
||||||
fldquals = self.fldquals
|
|
||||||
if fldquals is None:
|
|
||||||
fldquals = (0,) * len(self.fldnames)
|
|
||||||
for name, type, bitsize, quals in zip(self.fldnames, self.fldtypes,
|
|
||||||
self.fldbitsize, fldquals):
|
|
||||||
if (name == '' and isinstance(type, StructOrUnion)
|
|
||||||
and expand_anonymous_struct_union):
|
|
||||||
# nested anonymous struct/union
|
|
||||||
for result in type.enumfields():
|
|
||||||
yield result
|
|
||||||
else:
|
|
||||||
yield (name, type, bitsize, quals)
|
|
||||||
|
|
||||||
def force_flatten(self):
|
|
||||||
# force the struct or union to have a declaration that lists
|
|
||||||
# directly all fields returned by enumfields(), flattening
|
|
||||||
# nested anonymous structs/unions.
|
|
||||||
names = []
|
|
||||||
types = []
|
|
||||||
bitsizes = []
|
|
||||||
fldquals = []
|
|
||||||
for name, type, bitsize, quals in self.enumfields():
|
|
||||||
names.append(name)
|
|
||||||
types.append(type)
|
|
||||||
bitsizes.append(bitsize)
|
|
||||||
fldquals.append(quals)
|
|
||||||
self.fldnames = tuple(names)
|
|
||||||
self.fldtypes = tuple(types)
|
|
||||||
self.fldbitsize = tuple(bitsizes)
|
|
||||||
self.fldquals = tuple(fldquals)
|
|
||||||
|
|
||||||
def get_cached_btype(self, ffi, finishlist, can_delay=False):
|
|
||||||
BType = StructOrUnionOrEnum.get_cached_btype(self, ffi, finishlist,
|
|
||||||
can_delay)
|
|
||||||
if not can_delay:
|
|
||||||
self.finish_backend_type(ffi, finishlist)
|
|
||||||
return BType
|
|
||||||
|
|
||||||
def finish_backend_type(self, ffi, finishlist):
|
|
||||||
if self.completed:
|
|
||||||
if self.completed != 2:
|
|
||||||
raise NotImplementedError("recursive structure declaration "
|
|
||||||
"for '%s'" % (self.name,))
|
|
||||||
return
|
|
||||||
BType = ffi._cached_btypes[self]
|
|
||||||
#
|
|
||||||
self.completed = 1
|
|
||||||
#
|
|
||||||
if self.fldtypes is None:
|
|
||||||
pass # not completing it: it's an opaque struct
|
|
||||||
#
|
|
||||||
elif self.fixedlayout is None:
|
|
||||||
fldtypes = [tp.get_cached_btype(ffi, finishlist)
|
|
||||||
for tp in self.fldtypes]
|
|
||||||
lst = list(zip(self.fldnames, fldtypes, self.fldbitsize))
|
|
||||||
extra_flags = ()
|
|
||||||
if self.packed:
|
|
||||||
if self.packed == 1:
|
|
||||||
extra_flags = (8,) # SF_PACKED
|
|
||||||
else:
|
|
||||||
extra_flags = (0, self.packed)
|
|
||||||
ffi._backend.complete_struct_or_union(BType, lst, self,
|
|
||||||
-1, -1, *extra_flags)
|
|
||||||
#
|
|
||||||
else:
|
|
||||||
fldtypes = []
|
|
||||||
fieldofs, fieldsize, totalsize, totalalignment = self.fixedlayout
|
|
||||||
for i in range(len(self.fldnames)):
|
|
||||||
fsize = fieldsize[i]
|
|
||||||
ftype = self.fldtypes[i]
|
|
||||||
#
|
|
||||||
if isinstance(ftype, ArrayType) and ftype.length_is_unknown():
|
|
||||||
# fix the length to match the total size
|
|
||||||
BItemType = ftype.item.get_cached_btype(ffi, finishlist)
|
|
||||||
nlen, nrest = divmod(fsize, ffi.sizeof(BItemType))
|
|
||||||
if nrest != 0:
|
|
||||||
self._verification_error(
|
|
||||||
"field '%s.%s' has a bogus size?" % (
|
|
||||||
self.name, self.fldnames[i] or '{}'))
|
|
||||||
ftype = ftype.resolve_length(nlen)
|
|
||||||
self.fldtypes = (self.fldtypes[:i] + (ftype,) +
|
|
||||||
self.fldtypes[i+1:])
|
|
||||||
#
|
|
||||||
BFieldType = ftype.get_cached_btype(ffi, finishlist)
|
|
||||||
if isinstance(ftype, ArrayType) and ftype.length is None:
|
|
||||||
assert fsize == 0
|
|
||||||
else:
|
|
||||||
bitemsize = ffi.sizeof(BFieldType)
|
|
||||||
if bitemsize != fsize:
|
|
||||||
self._verification_error(
|
|
||||||
"field '%s.%s' is declared as %d bytes, but is "
|
|
||||||
"really %d bytes" % (self.name,
|
|
||||||
self.fldnames[i] or '{}',
|
|
||||||
bitemsize, fsize))
|
|
||||||
fldtypes.append(BFieldType)
|
|
||||||
#
|
|
||||||
lst = list(zip(self.fldnames, fldtypes, self.fldbitsize, fieldofs))
|
|
||||||
ffi._backend.complete_struct_or_union(BType, lst, self,
|
|
||||||
totalsize, totalalignment)
|
|
||||||
self.completed = 2
|
|
||||||
|
|
||||||
def _verification_error(self, msg):
|
|
||||||
raise VerificationError(msg)
|
|
||||||
|
|
||||||
def check_not_partial(self):
|
|
||||||
if self.partial and self.fixedlayout is None:
|
|
||||||
raise VerificationMissing(self._get_c_name())
|
|
||||||
|
|
||||||
def build_backend_type(self, ffi, finishlist):
|
|
||||||
self.check_not_partial()
|
|
||||||
finishlist.append(self)
|
|
||||||
#
|
|
||||||
return global_cache(self, ffi, 'new_%s_type' % self.kind,
|
|
||||||
self.get_official_name(), key=self)
|
|
||||||
|
|
||||||
|
|
||||||
class StructType(StructOrUnion):
|
|
||||||
kind = 'struct'
|
|
||||||
|
|
||||||
|
|
||||||
class UnionType(StructOrUnion):
|
|
||||||
kind = 'union'
|
|
||||||
|
|
||||||
|
|
||||||
class EnumType(StructOrUnionOrEnum):
|
|
||||||
kind = 'enum'
|
|
||||||
partial = False
|
|
||||||
partial_resolved = False
|
|
||||||
|
|
||||||
def __init__(self, name, enumerators, enumvalues, baseinttype=None):
|
|
||||||
self.name = name
|
|
||||||
self.enumerators = enumerators
|
|
||||||
self.enumvalues = enumvalues
|
|
||||||
self.baseinttype = baseinttype
|
|
||||||
self.build_c_name_with_marker()
|
|
||||||
|
|
||||||
def force_the_name(self, forcename):
|
|
||||||
StructOrUnionOrEnum.force_the_name(self, forcename)
|
|
||||||
if self.forcename is None:
|
|
||||||
name = self.get_official_name()
|
|
||||||
self.forcename = '$' + name.replace(' ', '_')
|
|
||||||
|
|
||||||
def check_not_partial(self):
|
|
||||||
if self.partial and not self.partial_resolved:
|
|
||||||
raise VerificationMissing(self._get_c_name())
|
|
||||||
|
|
||||||
def build_backend_type(self, ffi, finishlist):
|
|
||||||
self.check_not_partial()
|
|
||||||
base_btype = self.build_baseinttype(ffi, finishlist)
|
|
||||||
return global_cache(self, ffi, 'new_enum_type',
|
|
||||||
self.get_official_name(),
|
|
||||||
self.enumerators, self.enumvalues,
|
|
||||||
base_btype, key=self)
|
|
||||||
|
|
||||||
def build_baseinttype(self, ffi, finishlist):
|
|
||||||
if self.baseinttype is not None:
|
|
||||||
return self.baseinttype.get_cached_btype(ffi, finishlist)
|
|
||||||
#
|
|
||||||
if self.enumvalues:
|
|
||||||
smallest_value = min(self.enumvalues)
|
|
||||||
largest_value = max(self.enumvalues)
|
|
||||||
else:
|
|
||||||
import warnings
|
|
||||||
try:
|
|
||||||
# XXX! The goal is to ensure that the warnings.warn()
|
|
||||||
# will not suppress the warning. We want to get it
|
|
||||||
# several times if we reach this point several times.
|
|
||||||
__warningregistry__.clear()
|
|
||||||
except NameError:
|
|
||||||
pass
|
|
||||||
warnings.warn("%r has no values explicitly defined; "
|
|
||||||
"guessing that it is equivalent to 'unsigned int'"
|
|
||||||
% self._get_c_name())
|
|
||||||
smallest_value = largest_value = 0
|
|
||||||
if smallest_value < 0: # needs a signed type
|
|
||||||
sign = 1
|
|
||||||
candidate1 = PrimitiveType("int")
|
|
||||||
candidate2 = PrimitiveType("long")
|
|
||||||
else:
|
|
||||||
sign = 0
|
|
||||||
candidate1 = PrimitiveType("unsigned int")
|
|
||||||
candidate2 = PrimitiveType("unsigned long")
|
|
||||||
btype1 = candidate1.get_cached_btype(ffi, finishlist)
|
|
||||||
btype2 = candidate2.get_cached_btype(ffi, finishlist)
|
|
||||||
size1 = ffi.sizeof(btype1)
|
|
||||||
size2 = ffi.sizeof(btype2)
|
|
||||||
if (smallest_value >= ((-1) << (8*size1-1)) and
|
|
||||||
largest_value < (1 << (8*size1-sign))):
|
|
||||||
return btype1
|
|
||||||
if (smallest_value >= ((-1) << (8*size2-1)) and
|
|
||||||
largest_value < (1 << (8*size2-sign))):
|
|
||||||
return btype2
|
|
||||||
raise CDefError("%s values don't all fit into either 'long' "
|
|
||||||
"or 'unsigned long'" % self._get_c_name())
|
|
||||||
|
|
||||||
def unknown_type(name, structname=None):
|
|
||||||
if structname is None:
|
|
||||||
structname = '$%s' % name
|
|
||||||
tp = StructType(structname, None, None, None)
|
|
||||||
tp.force_the_name(name)
|
|
||||||
tp.origin = "unknown_type"
|
|
||||||
return tp
|
|
||||||
|
|
||||||
def unknown_ptr_type(name, structname=None):
|
|
||||||
if structname is None:
|
|
||||||
structname = '$$%s' % name
|
|
||||||
tp = StructType(structname, None, None, None)
|
|
||||||
return NamedPointerType(tp, name)
|
|
||||||
|
|
||||||
|
|
||||||
global_lock = allocate_lock()
|
|
||||||
_typecache_cffi_backend = weakref.WeakValueDictionary()
|
|
||||||
|
|
||||||
def get_typecache(backend):
|
|
||||||
# returns _typecache_cffi_backend if backend is the _cffi_backend
|
|
||||||
# module, or type(backend).__typecache if backend is an instance of
|
|
||||||
# CTypesBackend (or some FakeBackend class during tests)
|
|
||||||
if isinstance(backend, types.ModuleType):
|
|
||||||
return _typecache_cffi_backend
|
|
||||||
with global_lock:
|
|
||||||
if not hasattr(type(backend), '__typecache'):
|
|
||||||
type(backend).__typecache = weakref.WeakValueDictionary()
|
|
||||||
return type(backend).__typecache
|
|
||||||
|
|
||||||
def global_cache(srctype, ffi, funcname, *args, **kwds):
|
|
||||||
key = kwds.pop('key', (funcname, args))
|
|
||||||
assert not kwds
|
|
||||||
try:
|
|
||||||
return ffi._typecache[key]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
res = getattr(ffi._backend, funcname)(*args)
|
|
||||||
except NotImplementedError as e:
|
|
||||||
raise NotImplementedError("%s: %r: %s" % (funcname, srctype, e))
|
|
||||||
# note that setdefault() on WeakValueDictionary is not atomic
|
|
||||||
# and contains a rare bug (http://bugs.python.org/issue19542);
|
|
||||||
# we have to use a lock and do it ourselves
|
|
||||||
cache = ffi._typecache
|
|
||||||
with global_lock:
|
|
||||||
res1 = cache.get(key)
|
|
||||||
if res1 is None:
|
|
||||||
cache[key] = res
|
|
||||||
return res
|
|
||||||
else:
|
|
||||||
return res1
|
|
||||||
|
|
||||||
def pointer_cache(ffi, BType):
|
|
||||||
return global_cache('?', ffi, 'new_pointer_type', BType)
|
|
||||||
|
|
||||||
def attach_exception_info(e, name):
|
|
||||||
if e.args and type(e.args[0]) is str:
|
|
||||||
e.args = ('%s: %s' % (name, e.args[0]),) + e.args[1:]
|
|
||||||
@@ -1,121 +0,0 @@
|
|||||||
# pkg-config, https://www.freedesktop.org/wiki/Software/pkg-config/ integration for cffi
|
|
||||||
import sys, os, subprocess
|
|
||||||
|
|
||||||
from .error import PkgConfigError
|
|
||||||
|
|
||||||
|
|
||||||
def merge_flags(cfg1, cfg2):
|
|
||||||
"""Merge values from cffi config flags cfg2 to cf1
|
|
||||||
|
|
||||||
Example:
|
|
||||||
merge_flags({"libraries": ["one"]}, {"libraries": ["two"]})
|
|
||||||
{"libraries": ["one", "two"]}
|
|
||||||
"""
|
|
||||||
for key, value in cfg2.items():
|
|
||||||
if key not in cfg1:
|
|
||||||
cfg1[key] = value
|
|
||||||
else:
|
|
||||||
if not isinstance(cfg1[key], list):
|
|
||||||
raise TypeError("cfg1[%r] should be a list of strings" % (key,))
|
|
||||||
if not isinstance(value, list):
|
|
||||||
raise TypeError("cfg2[%r] should be a list of strings" % (key,))
|
|
||||||
cfg1[key].extend(value)
|
|
||||||
return cfg1
|
|
||||||
|
|
||||||
|
|
||||||
def call(libname, flag, encoding=sys.getfilesystemencoding()):
|
|
||||||
"""Calls pkg-config and returns the output if found
|
|
||||||
"""
|
|
||||||
a = ["pkg-config", "--print-errors"]
|
|
||||||
a.append(flag)
|
|
||||||
a.append(libname)
|
|
||||||
try:
|
|
||||||
pc = subprocess.Popen(a, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
|
||||||
except EnvironmentError as e:
|
|
||||||
raise PkgConfigError("cannot run pkg-config: %s" % (str(e).strip(),))
|
|
||||||
|
|
||||||
bout, berr = pc.communicate()
|
|
||||||
if pc.returncode != 0:
|
|
||||||
try:
|
|
||||||
berr = berr.decode(encoding)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
raise PkgConfigError(berr.strip())
|
|
||||||
|
|
||||||
if sys.version_info >= (3,) and not isinstance(bout, str): # Python 3.x
|
|
||||||
try:
|
|
||||||
bout = bout.decode(encoding)
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
raise PkgConfigError("pkg-config %s %s returned bytes that cannot "
|
|
||||||
"be decoded with encoding %r:\n%r" %
|
|
||||||
(flag, libname, encoding, bout))
|
|
||||||
|
|
||||||
if os.altsep != '\\' and '\\' in bout:
|
|
||||||
raise PkgConfigError("pkg-config %s %s returned an unsupported "
|
|
||||||
"backslash-escaped output:\n%r" %
|
|
||||||
(flag, libname, bout))
|
|
||||||
return bout
|
|
||||||
|
|
||||||
|
|
||||||
def flags_from_pkgconfig(libs):
|
|
||||||
r"""Return compiler line flags for FFI.set_source based on pkg-config output
|
|
||||||
|
|
||||||
Usage
|
|
||||||
...
|
|
||||||
ffibuilder.set_source("_foo", pkgconfig = ["libfoo", "libbar >= 1.8.3"])
|
|
||||||
|
|
||||||
If pkg-config is installed on build machine, then arguments include_dirs,
|
|
||||||
library_dirs, libraries, define_macros, extra_compile_args and
|
|
||||||
extra_link_args are extended with an output of pkg-config for libfoo and
|
|
||||||
libbar.
|
|
||||||
|
|
||||||
Raises PkgConfigError in case the pkg-config call fails.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_include_dirs(string):
|
|
||||||
return [x[2:] for x in string.split() if x.startswith("-I")]
|
|
||||||
|
|
||||||
def get_library_dirs(string):
|
|
||||||
return [x[2:] for x in string.split() if x.startswith("-L")]
|
|
||||||
|
|
||||||
def get_libraries(string):
|
|
||||||
return [x[2:] for x in string.split() if x.startswith("-l")]
|
|
||||||
|
|
||||||
# convert -Dfoo=bar to list of tuples [("foo", "bar")] expected by distutils
|
|
||||||
def get_macros(string):
|
|
||||||
def _macro(x):
|
|
||||||
x = x[2:] # drop "-D"
|
|
||||||
if '=' in x:
|
|
||||||
return tuple(x.split("=", 1)) # "-Dfoo=bar" => ("foo", "bar")
|
|
||||||
else:
|
|
||||||
return (x, None) # "-Dfoo" => ("foo", None)
|
|
||||||
return [_macro(x) for x in string.split() if x.startswith("-D")]
|
|
||||||
|
|
||||||
def get_other_cflags(string):
|
|
||||||
return [x for x in string.split() if not x.startswith("-I") and
|
|
||||||
not x.startswith("-D")]
|
|
||||||
|
|
||||||
def get_other_libs(string):
|
|
||||||
return [x for x in string.split() if not x.startswith("-L") and
|
|
||||||
not x.startswith("-l")]
|
|
||||||
|
|
||||||
# return kwargs for given libname
|
|
||||||
def kwargs(libname):
|
|
||||||
fse = sys.getfilesystemencoding()
|
|
||||||
all_cflags = call(libname, "--cflags")
|
|
||||||
all_libs = call(libname, "--libs")
|
|
||||||
return {
|
|
||||||
"include_dirs": get_include_dirs(all_cflags),
|
|
||||||
"library_dirs": get_library_dirs(all_libs),
|
|
||||||
"libraries": get_libraries(all_libs),
|
|
||||||
"define_macros": get_macros(all_cflags),
|
|
||||||
"extra_compile_args": get_other_cflags(all_cflags),
|
|
||||||
"extra_link_args": get_other_libs(all_libs),
|
|
||||||
}
|
|
||||||
|
|
||||||
# merge all arguments together
|
|
||||||
ret = {}
|
|
||||||
for libname in libs:
|
|
||||||
lib_flags = kwargs(libname)
|
|
||||||
merge_flags(ret, lib_flags)
|
|
||||||
return ret
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,219 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
try:
|
|
||||||
basestring
|
|
||||||
except NameError:
|
|
||||||
# Python 3.x
|
|
||||||
basestring = str
|
|
||||||
|
|
||||||
def error(msg):
|
|
||||||
from distutils.errors import DistutilsSetupError
|
|
||||||
raise DistutilsSetupError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def execfile(filename, glob):
|
|
||||||
# We use execfile() (here rewritten for Python 3) instead of
|
|
||||||
# __import__() to load the build script. The problem with
|
|
||||||
# a normal import is that in some packages, the intermediate
|
|
||||||
# __init__.py files may already try to import the file that
|
|
||||||
# we are generating.
|
|
||||||
with open(filename) as f:
|
|
||||||
src = f.read()
|
|
||||||
src += '\n' # Python 2.6 compatibility
|
|
||||||
code = compile(src, filename, 'exec')
|
|
||||||
exec(code, glob, glob)
|
|
||||||
|
|
||||||
|
|
||||||
def add_cffi_module(dist, mod_spec):
|
|
||||||
from cffi.api import FFI
|
|
||||||
|
|
||||||
if not isinstance(mod_spec, basestring):
|
|
||||||
error("argument to 'cffi_modules=...' must be a str or a list of str,"
|
|
||||||
" not %r" % (type(mod_spec).__name__,))
|
|
||||||
mod_spec = str(mod_spec)
|
|
||||||
try:
|
|
||||||
build_file_name, ffi_var_name = mod_spec.split(':')
|
|
||||||
except ValueError:
|
|
||||||
error("%r must be of the form 'path/build.py:ffi_variable'" %
|
|
||||||
(mod_spec,))
|
|
||||||
if not os.path.exists(build_file_name):
|
|
||||||
ext = ''
|
|
||||||
rewritten = build_file_name.replace('.', '/') + '.py'
|
|
||||||
if os.path.exists(rewritten):
|
|
||||||
ext = ' (rewrite cffi_modules to [%r])' % (
|
|
||||||
rewritten + ':' + ffi_var_name,)
|
|
||||||
error("%r does not name an existing file%s" % (build_file_name, ext))
|
|
||||||
|
|
||||||
mod_vars = {'__name__': '__cffi__', '__file__': build_file_name}
|
|
||||||
execfile(build_file_name, mod_vars)
|
|
||||||
|
|
||||||
try:
|
|
||||||
ffi = mod_vars[ffi_var_name]
|
|
||||||
except KeyError:
|
|
||||||
error("%r: object %r not found in module" % (mod_spec,
|
|
||||||
ffi_var_name))
|
|
||||||
if not isinstance(ffi, FFI):
|
|
||||||
ffi = ffi() # maybe it's a function instead of directly an ffi
|
|
||||||
if not isinstance(ffi, FFI):
|
|
||||||
error("%r is not an FFI instance (got %r)" % (mod_spec,
|
|
||||||
type(ffi).__name__))
|
|
||||||
if not hasattr(ffi, '_assigned_source'):
|
|
||||||
error("%r: the set_source() method was not called" % (mod_spec,))
|
|
||||||
module_name, source, source_extension, kwds = ffi._assigned_source
|
|
||||||
if ffi._windows_unicode:
|
|
||||||
kwds = kwds.copy()
|
|
||||||
ffi._apply_windows_unicode(kwds)
|
|
||||||
|
|
||||||
if source is None:
|
|
||||||
_add_py_module(dist, ffi, module_name)
|
|
||||||
else:
|
|
||||||
_add_c_module(dist, ffi, module_name, source, source_extension, kwds)
|
|
||||||
|
|
||||||
def _set_py_limited_api(Extension, kwds):
|
|
||||||
"""
|
|
||||||
Add py_limited_api to kwds if setuptools >= 26 is in use.
|
|
||||||
Do not alter the setting if it already exists.
|
|
||||||
Setuptools takes care of ignoring the flag on Python 2 and PyPy.
|
|
||||||
|
|
||||||
CPython itself should ignore the flag in a debugging version
|
|
||||||
(by not listing .abi3.so in the extensions it supports), but
|
|
||||||
it doesn't so far, creating troubles. That's why we check
|
|
||||||
for "not hasattr(sys, 'gettotalrefcount')" (the 2.7 compatible equivalent
|
|
||||||
of 'd' not in sys.abiflags). (http://bugs.python.org/issue28401)
|
|
||||||
|
|
||||||
On Windows, with CPython <= 3.4, it's better not to use py_limited_api
|
|
||||||
because virtualenv *still* doesn't copy PYTHON3.DLL on these versions.
|
|
||||||
Recently (2020) we started shipping only >= 3.5 wheels, though. So
|
|
||||||
we'll give it another try and set py_limited_api on Windows >= 3.5.
|
|
||||||
"""
|
|
||||||
from cffi import recompiler
|
|
||||||
|
|
||||||
if ('py_limited_api' not in kwds and not hasattr(sys, 'gettotalrefcount')
|
|
||||||
and recompiler.USE_LIMITED_API):
|
|
||||||
import setuptools
|
|
||||||
try:
|
|
||||||
setuptools_major_version = int(setuptools.__version__.partition('.')[0])
|
|
||||||
if setuptools_major_version >= 26:
|
|
||||||
kwds['py_limited_api'] = True
|
|
||||||
except ValueError: # certain development versions of setuptools
|
|
||||||
# If we don't know the version number of setuptools, we
|
|
||||||
# try to set 'py_limited_api' anyway. At worst, we get a
|
|
||||||
# warning.
|
|
||||||
kwds['py_limited_api'] = True
|
|
||||||
return kwds
|
|
||||||
|
|
||||||
def _add_c_module(dist, ffi, module_name, source, source_extension, kwds):
|
|
||||||
from distutils.core import Extension
|
|
||||||
# We are a setuptools extension. Need this build_ext for py_limited_api.
|
|
||||||
from setuptools.command.build_ext import build_ext
|
|
||||||
from distutils.dir_util import mkpath
|
|
||||||
from distutils import log
|
|
||||||
from cffi import recompiler
|
|
||||||
|
|
||||||
allsources = ['$PLACEHOLDER']
|
|
||||||
allsources.extend(kwds.pop('sources', []))
|
|
||||||
kwds = _set_py_limited_api(Extension, kwds)
|
|
||||||
ext = Extension(name=module_name, sources=allsources, **kwds)
|
|
||||||
|
|
||||||
def make_mod(tmpdir, pre_run=None):
|
|
||||||
c_file = os.path.join(tmpdir, module_name + source_extension)
|
|
||||||
log.info("generating cffi module %r" % c_file)
|
|
||||||
mkpath(tmpdir)
|
|
||||||
# a setuptools-only, API-only hook: called with the "ext" and "ffi"
|
|
||||||
# arguments just before we turn the ffi into C code. To use it,
|
|
||||||
# subclass the 'distutils.command.build_ext.build_ext' class and
|
|
||||||
# add a method 'def pre_run(self, ext, ffi)'.
|
|
||||||
if pre_run is not None:
|
|
||||||
pre_run(ext, ffi)
|
|
||||||
updated = recompiler.make_c_source(ffi, module_name, source, c_file)
|
|
||||||
if not updated:
|
|
||||||
log.info("already up-to-date")
|
|
||||||
return c_file
|
|
||||||
|
|
||||||
if dist.ext_modules is None:
|
|
||||||
dist.ext_modules = []
|
|
||||||
dist.ext_modules.append(ext)
|
|
||||||
|
|
||||||
base_class = dist.cmdclass.get('build_ext', build_ext)
|
|
||||||
class build_ext_make_mod(base_class):
|
|
||||||
def run(self):
|
|
||||||
if ext.sources[0] == '$PLACEHOLDER':
|
|
||||||
pre_run = getattr(self, 'pre_run', None)
|
|
||||||
ext.sources[0] = make_mod(self.build_temp, pre_run)
|
|
||||||
base_class.run(self)
|
|
||||||
dist.cmdclass['build_ext'] = build_ext_make_mod
|
|
||||||
# NB. multiple runs here will create multiple 'build_ext_make_mod'
|
|
||||||
# classes. Even in this case the 'build_ext' command should be
|
|
||||||
# run once; but just in case, the logic above does nothing if
|
|
||||||
# called again.
|
|
||||||
|
|
||||||
|
|
||||||
def _add_py_module(dist, ffi, module_name):
|
|
||||||
from distutils.dir_util import mkpath
|
|
||||||
from setuptools.command.build_py import build_py
|
|
||||||
from setuptools.command.build_ext import build_ext
|
|
||||||
from distutils import log
|
|
||||||
from cffi import recompiler
|
|
||||||
|
|
||||||
def generate_mod(py_file):
|
|
||||||
log.info("generating cffi module %r" % py_file)
|
|
||||||
mkpath(os.path.dirname(py_file))
|
|
||||||
updated = recompiler.make_py_source(ffi, module_name, py_file)
|
|
||||||
if not updated:
|
|
||||||
log.info("already up-to-date")
|
|
||||||
|
|
||||||
base_class = dist.cmdclass.get('build_py', build_py)
|
|
||||||
class build_py_make_mod(base_class):
|
|
||||||
def run(self):
|
|
||||||
base_class.run(self)
|
|
||||||
module_path = module_name.split('.')
|
|
||||||
module_path[-1] += '.py'
|
|
||||||
generate_mod(os.path.join(self.build_lib, *module_path))
|
|
||||||
def get_source_files(self):
|
|
||||||
# This is called from 'setup.py sdist' only. Exclude
|
|
||||||
# the generate .py module in this case.
|
|
||||||
saved_py_modules = self.py_modules
|
|
||||||
try:
|
|
||||||
if saved_py_modules:
|
|
||||||
self.py_modules = [m for m in saved_py_modules
|
|
||||||
if m != module_name]
|
|
||||||
return base_class.get_source_files(self)
|
|
||||||
finally:
|
|
||||||
self.py_modules = saved_py_modules
|
|
||||||
dist.cmdclass['build_py'] = build_py_make_mod
|
|
||||||
|
|
||||||
# distutils and setuptools have no notion I could find of a
|
|
||||||
# generated python module. If we don't add module_name to
|
|
||||||
# dist.py_modules, then things mostly work but there are some
|
|
||||||
# combination of options (--root and --record) that will miss
|
|
||||||
# the module. So we add it here, which gives a few apparently
|
|
||||||
# harmless warnings about not finding the file outside the
|
|
||||||
# build directory.
|
|
||||||
# Then we need to hack more in get_source_files(); see above.
|
|
||||||
if dist.py_modules is None:
|
|
||||||
dist.py_modules = []
|
|
||||||
dist.py_modules.append(module_name)
|
|
||||||
|
|
||||||
# the following is only for "build_ext -i"
|
|
||||||
base_class_2 = dist.cmdclass.get('build_ext', build_ext)
|
|
||||||
class build_ext_make_mod(base_class_2):
|
|
||||||
def run(self):
|
|
||||||
base_class_2.run(self)
|
|
||||||
if self.inplace:
|
|
||||||
# from get_ext_fullpath() in distutils/command/build_ext.py
|
|
||||||
module_path = module_name.split('.')
|
|
||||||
package = '.'.join(module_path[:-1])
|
|
||||||
build_py = self.get_finalized_command('build_py')
|
|
||||||
package_dir = build_py.get_package_dir(package)
|
|
||||||
file_name = module_path[-1] + '.py'
|
|
||||||
generate_mod(os.path.join(package_dir, file_name))
|
|
||||||
dist.cmdclass['build_ext'] = build_ext_make_mod
|
|
||||||
|
|
||||||
def cffi_modules(dist, attr, value):
|
|
||||||
assert attr == 'cffi_modules'
|
|
||||||
if isinstance(value, basestring):
|
|
||||||
value = [value]
|
|
||||||
|
|
||||||
for cffi_module in value:
|
|
||||||
add_cffi_module(dist, cffi_module)
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,675 +0,0 @@
|
|||||||
#
|
|
||||||
# DEPRECATED: implementation for ffi.verify()
|
|
||||||
#
|
|
||||||
import sys, os
|
|
||||||
import types
|
|
||||||
|
|
||||||
from . import model
|
|
||||||
from .error import VerificationError
|
|
||||||
|
|
||||||
|
|
||||||
class VGenericEngine(object):
|
|
||||||
_class_key = 'g'
|
|
||||||
_gen_python_module = False
|
|
||||||
|
|
||||||
def __init__(self, verifier):
|
|
||||||
self.verifier = verifier
|
|
||||||
self.ffi = verifier.ffi
|
|
||||||
self.export_symbols = []
|
|
||||||
self._struct_pending_verification = {}
|
|
||||||
|
|
||||||
def patch_extension_kwds(self, kwds):
|
|
||||||
# add 'export_symbols' to the dictionary. Note that we add the
|
|
||||||
# list before filling it. When we fill it, it will thus also show
|
|
||||||
# up in kwds['export_symbols'].
|
|
||||||
kwds.setdefault('export_symbols', self.export_symbols)
|
|
||||||
|
|
||||||
def find_module(self, module_name, path, so_suffixes):
|
|
||||||
for so_suffix in so_suffixes:
|
|
||||||
basename = module_name + so_suffix
|
|
||||||
if path is None:
|
|
||||||
path = sys.path
|
|
||||||
for dirname in path:
|
|
||||||
filename = os.path.join(dirname, basename)
|
|
||||||
if os.path.isfile(filename):
|
|
||||||
return filename
|
|
||||||
|
|
||||||
def collect_types(self):
|
|
||||||
pass # not needed in the generic engine
|
|
||||||
|
|
||||||
def _prnt(self, what=''):
|
|
||||||
self._f.write(what + '\n')
|
|
||||||
|
|
||||||
def write_source_to_f(self):
|
|
||||||
prnt = self._prnt
|
|
||||||
# first paste some standard set of lines that are mostly '#include'
|
|
||||||
prnt(cffimod_header)
|
|
||||||
# then paste the C source given by the user, verbatim.
|
|
||||||
prnt(self.verifier.preamble)
|
|
||||||
#
|
|
||||||
# call generate_gen_xxx_decl(), for every xxx found from
|
|
||||||
# ffi._parser._declarations. This generates all the functions.
|
|
||||||
self._generate('decl')
|
|
||||||
#
|
|
||||||
# on Windows, distutils insists on putting init_cffi_xyz in
|
|
||||||
# 'export_symbols', so instead of fighting it, just give up and
|
|
||||||
# give it one
|
|
||||||
if sys.platform == 'win32':
|
|
||||||
if sys.version_info >= (3,):
|
|
||||||
prefix = 'PyInit_'
|
|
||||||
else:
|
|
||||||
prefix = 'init'
|
|
||||||
modname = self.verifier.get_module_name()
|
|
||||||
prnt("void %s%s(void) { }\n" % (prefix, modname))
|
|
||||||
|
|
||||||
def load_library(self, flags=0):
|
|
||||||
# import it with the CFFI backend
|
|
||||||
backend = self.ffi._backend
|
|
||||||
# needs to make a path that contains '/', on Posix
|
|
||||||
filename = os.path.join(os.curdir, self.verifier.modulefilename)
|
|
||||||
module = backend.load_library(filename, flags)
|
|
||||||
#
|
|
||||||
# call loading_gen_struct() to get the struct layout inferred by
|
|
||||||
# the C compiler
|
|
||||||
self._load(module, 'loading')
|
|
||||||
|
|
||||||
# build the FFILibrary class and instance, this is a module subclass
|
|
||||||
# because modules are expected to have usually-constant-attributes and
|
|
||||||
# in PyPy this means the JIT is able to treat attributes as constant,
|
|
||||||
# which we want.
|
|
||||||
class FFILibrary(types.ModuleType):
|
|
||||||
_cffi_generic_module = module
|
|
||||||
_cffi_ffi = self.ffi
|
|
||||||
_cffi_dir = []
|
|
||||||
def __dir__(self):
|
|
||||||
return FFILibrary._cffi_dir
|
|
||||||
library = FFILibrary("")
|
|
||||||
#
|
|
||||||
# finally, call the loaded_gen_xxx() functions. This will set
|
|
||||||
# up the 'library' object.
|
|
||||||
self._load(module, 'loaded', library=library)
|
|
||||||
return library
|
|
||||||
|
|
||||||
def _get_declarations(self):
|
|
||||||
lst = [(key, tp) for (key, (tp, qual)) in
|
|
||||||
self.ffi._parser._declarations.items()]
|
|
||||||
lst.sort()
|
|
||||||
return lst
|
|
||||||
|
|
||||||
def _generate(self, step_name):
|
|
||||||
for name, tp in self._get_declarations():
|
|
||||||
kind, realname = name.split(' ', 1)
|
|
||||||
try:
|
|
||||||
method = getattr(self, '_generate_gen_%s_%s' % (kind,
|
|
||||||
step_name))
|
|
||||||
except AttributeError:
|
|
||||||
raise VerificationError(
|
|
||||||
"not implemented in verify(): %r" % name)
|
|
||||||
try:
|
|
||||||
method(tp, realname)
|
|
||||||
except Exception as e:
|
|
||||||
model.attach_exception_info(e, name)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _load(self, module, step_name, **kwds):
|
|
||||||
for name, tp in self._get_declarations():
|
|
||||||
kind, realname = name.split(' ', 1)
|
|
||||||
method = getattr(self, '_%s_gen_%s' % (step_name, kind))
|
|
||||||
try:
|
|
||||||
method(tp, realname, module, **kwds)
|
|
||||||
except Exception as e:
|
|
||||||
model.attach_exception_info(e, name)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _generate_nothing(self, tp, name):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _loaded_noop(self, tp, name, module, **kwds):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# ----------
|
|
||||||
# typedefs: generates no code so far
|
|
||||||
|
|
||||||
_generate_gen_typedef_decl = _generate_nothing
|
|
||||||
_loading_gen_typedef = _loaded_noop
|
|
||||||
_loaded_gen_typedef = _loaded_noop
|
|
||||||
|
|
||||||
# ----------
|
|
||||||
# function declarations
|
|
||||||
|
|
||||||
def _generate_gen_function_decl(self, tp, name):
|
|
||||||
assert isinstance(tp, model.FunctionPtrType)
|
|
||||||
if tp.ellipsis:
|
|
||||||
# cannot support vararg functions better than this: check for its
|
|
||||||
# exact type (including the fixed arguments), and build it as a
|
|
||||||
# constant function pointer (no _cffi_f_%s wrapper)
|
|
||||||
self._generate_gen_const(False, name, tp)
|
|
||||||
return
|
|
||||||
prnt = self._prnt
|
|
||||||
numargs = len(tp.args)
|
|
||||||
argnames = []
|
|
||||||
for i, type in enumerate(tp.args):
|
|
||||||
indirection = ''
|
|
||||||
if isinstance(type, model.StructOrUnion):
|
|
||||||
indirection = '*'
|
|
||||||
argnames.append('%sx%d' % (indirection, i))
|
|
||||||
context = 'argument of %s' % name
|
|
||||||
arglist = [type.get_c_name(' %s' % arg, context)
|
|
||||||
for type, arg in zip(tp.args, argnames)]
|
|
||||||
tpresult = tp.result
|
|
||||||
if isinstance(tpresult, model.StructOrUnion):
|
|
||||||
arglist.insert(0, tpresult.get_c_name(' *r', context))
|
|
||||||
tpresult = model.void_type
|
|
||||||
arglist = ', '.join(arglist) or 'void'
|
|
||||||
wrappername = '_cffi_f_%s' % name
|
|
||||||
self.export_symbols.append(wrappername)
|
|
||||||
if tp.abi:
|
|
||||||
abi = tp.abi + ' '
|
|
||||||
else:
|
|
||||||
abi = ''
|
|
||||||
funcdecl = ' %s%s(%s)' % (abi, wrappername, arglist)
|
|
||||||
context = 'result of %s' % name
|
|
||||||
prnt(tpresult.get_c_name(funcdecl, context))
|
|
||||||
prnt('{')
|
|
||||||
#
|
|
||||||
if isinstance(tp.result, model.StructOrUnion):
|
|
||||||
result_code = '*r = '
|
|
||||||
elif not isinstance(tp.result, model.VoidType):
|
|
||||||
result_code = 'return '
|
|
||||||
else:
|
|
||||||
result_code = ''
|
|
||||||
prnt(' %s%s(%s);' % (result_code, name, ', '.join(argnames)))
|
|
||||||
prnt('}')
|
|
||||||
prnt()
|
|
||||||
|
|
||||||
_loading_gen_function = _loaded_noop
|
|
||||||
|
|
||||||
def _loaded_gen_function(self, tp, name, module, library):
|
|
||||||
assert isinstance(tp, model.FunctionPtrType)
|
|
||||||
if tp.ellipsis:
|
|
||||||
newfunction = self._load_constant(False, tp, name, module)
|
|
||||||
else:
|
|
||||||
indirections = []
|
|
||||||
base_tp = tp
|
|
||||||
if (any(isinstance(typ, model.StructOrUnion) for typ in tp.args)
|
|
||||||
or isinstance(tp.result, model.StructOrUnion)):
|
|
||||||
indirect_args = []
|
|
||||||
for i, typ in enumerate(tp.args):
|
|
||||||
if isinstance(typ, model.StructOrUnion):
|
|
||||||
typ = model.PointerType(typ)
|
|
||||||
indirections.append((i, typ))
|
|
||||||
indirect_args.append(typ)
|
|
||||||
indirect_result = tp.result
|
|
||||||
if isinstance(indirect_result, model.StructOrUnion):
|
|
||||||
if indirect_result.fldtypes is None:
|
|
||||||
raise TypeError("'%s' is used as result type, "
|
|
||||||
"but is opaque" % (
|
|
||||||
indirect_result._get_c_name(),))
|
|
||||||
indirect_result = model.PointerType(indirect_result)
|
|
||||||
indirect_args.insert(0, indirect_result)
|
|
||||||
indirections.insert(0, ("result", indirect_result))
|
|
||||||
indirect_result = model.void_type
|
|
||||||
tp = model.FunctionPtrType(tuple(indirect_args),
|
|
||||||
indirect_result, tp.ellipsis)
|
|
||||||
BFunc = self.ffi._get_cached_btype(tp)
|
|
||||||
wrappername = '_cffi_f_%s' % name
|
|
||||||
newfunction = module.load_function(BFunc, wrappername)
|
|
||||||
for i, typ in indirections:
|
|
||||||
newfunction = self._make_struct_wrapper(newfunction, i, typ,
|
|
||||||
base_tp)
|
|
||||||
setattr(library, name, newfunction)
|
|
||||||
type(library)._cffi_dir.append(name)
|
|
||||||
|
|
||||||
def _make_struct_wrapper(self, oldfunc, i, tp, base_tp):
|
|
||||||
backend = self.ffi._backend
|
|
||||||
BType = self.ffi._get_cached_btype(tp)
|
|
||||||
if i == "result":
|
|
||||||
ffi = self.ffi
|
|
||||||
def newfunc(*args):
|
|
||||||
res = ffi.new(BType)
|
|
||||||
oldfunc(res, *args)
|
|
||||||
return res[0]
|
|
||||||
else:
|
|
||||||
def newfunc(*args):
|
|
||||||
args = args[:i] + (backend.newp(BType, args[i]),) + args[i+1:]
|
|
||||||
return oldfunc(*args)
|
|
||||||
newfunc._cffi_base_type = base_tp
|
|
||||||
return newfunc
|
|
||||||
|
|
||||||
# ----------
|
|
||||||
# named structs
|
|
||||||
|
|
||||||
def _generate_gen_struct_decl(self, tp, name):
|
|
||||||
assert name == tp.name
|
|
||||||
self._generate_struct_or_union_decl(tp, 'struct', name)
|
|
||||||
|
|
||||||
def _loading_gen_struct(self, tp, name, module):
|
|
||||||
self._loading_struct_or_union(tp, 'struct', name, module)
|
|
||||||
|
|
||||||
def _loaded_gen_struct(self, tp, name, module, **kwds):
|
|
||||||
self._loaded_struct_or_union(tp)
|
|
||||||
|
|
||||||
def _generate_gen_union_decl(self, tp, name):
|
|
||||||
assert name == tp.name
|
|
||||||
self._generate_struct_or_union_decl(tp, 'union', name)
|
|
||||||
|
|
||||||
def _loading_gen_union(self, tp, name, module):
|
|
||||||
self._loading_struct_or_union(tp, 'union', name, module)
|
|
||||||
|
|
||||||
def _loaded_gen_union(self, tp, name, module, **kwds):
|
|
||||||
self._loaded_struct_or_union(tp)
|
|
||||||
|
|
||||||
def _generate_struct_or_union_decl(self, tp, prefix, name):
|
|
||||||
if tp.fldnames is None:
|
|
||||||
return # nothing to do with opaque structs
|
|
||||||
checkfuncname = '_cffi_check_%s_%s' % (prefix, name)
|
|
||||||
layoutfuncname = '_cffi_layout_%s_%s' % (prefix, name)
|
|
||||||
cname = ('%s %s' % (prefix, name)).strip()
|
|
||||||
#
|
|
||||||
prnt = self._prnt
|
|
||||||
prnt('static void %s(%s *p)' % (checkfuncname, cname))
|
|
||||||
prnt('{')
|
|
||||||
prnt(' /* only to generate compile-time warnings or errors */')
|
|
||||||
prnt(' (void)p;')
|
|
||||||
for fname, ftype, fbitsize, fqual in tp.enumfields():
|
|
||||||
if (isinstance(ftype, model.PrimitiveType)
|
|
||||||
and ftype.is_integer_type()) or fbitsize >= 0:
|
|
||||||
# accept all integers, but complain on float or double
|
|
||||||
prnt(' (void)((p->%s) << 1);' % fname)
|
|
||||||
else:
|
|
||||||
# only accept exactly the type declared.
|
|
||||||
try:
|
|
||||||
prnt(' { %s = &p->%s; (void)tmp; }' % (
|
|
||||||
ftype.get_c_name('*tmp', 'field %r'%fname, quals=fqual),
|
|
||||||
fname))
|
|
||||||
except VerificationError as e:
|
|
||||||
prnt(' /* %s */' % str(e)) # cannot verify it, ignore
|
|
||||||
prnt('}')
|
|
||||||
self.export_symbols.append(layoutfuncname)
|
|
||||||
prnt('intptr_t %s(intptr_t i)' % (layoutfuncname,))
|
|
||||||
prnt('{')
|
|
||||||
prnt(' struct _cffi_aligncheck { char x; %s y; };' % cname)
|
|
||||||
prnt(' static intptr_t nums[] = {')
|
|
||||||
prnt(' sizeof(%s),' % cname)
|
|
||||||
prnt(' offsetof(struct _cffi_aligncheck, y),')
|
|
||||||
for fname, ftype, fbitsize, fqual in tp.enumfields():
|
|
||||||
if fbitsize >= 0:
|
|
||||||
continue # xxx ignore fbitsize for now
|
|
||||||
prnt(' offsetof(%s, %s),' % (cname, fname))
|
|
||||||
if isinstance(ftype, model.ArrayType) and ftype.length is None:
|
|
||||||
prnt(' 0, /* %s */' % ftype._get_c_name())
|
|
||||||
else:
|
|
||||||
prnt(' sizeof(((%s *)0)->%s),' % (cname, fname))
|
|
||||||
prnt(' -1')
|
|
||||||
prnt(' };')
|
|
||||||
prnt(' return nums[i];')
|
|
||||||
prnt(' /* the next line is not executed, but compiled */')
|
|
||||||
prnt(' %s(0);' % (checkfuncname,))
|
|
||||||
prnt('}')
|
|
||||||
prnt()
|
|
||||||
|
|
||||||
def _loading_struct_or_union(self, tp, prefix, name, module):
|
|
||||||
if tp.fldnames is None:
|
|
||||||
return # nothing to do with opaque structs
|
|
||||||
layoutfuncname = '_cffi_layout_%s_%s' % (prefix, name)
|
|
||||||
#
|
|
||||||
BFunc = self.ffi._typeof_locked("intptr_t(*)(intptr_t)")[0]
|
|
||||||
function = module.load_function(BFunc, layoutfuncname)
|
|
||||||
layout = []
|
|
||||||
num = 0
|
|
||||||
while True:
|
|
||||||
x = function(num)
|
|
||||||
if x < 0: break
|
|
||||||
layout.append(x)
|
|
||||||
num += 1
|
|
||||||
if isinstance(tp, model.StructOrUnion) and tp.partial:
|
|
||||||
# use the function()'s sizes and offsets to guide the
|
|
||||||
# layout of the struct
|
|
||||||
totalsize = layout[0]
|
|
||||||
totalalignment = layout[1]
|
|
||||||
fieldofs = layout[2::2]
|
|
||||||
fieldsize = layout[3::2]
|
|
||||||
tp.force_flatten()
|
|
||||||
assert len(fieldofs) == len(fieldsize) == len(tp.fldnames)
|
|
||||||
tp.fixedlayout = fieldofs, fieldsize, totalsize, totalalignment
|
|
||||||
else:
|
|
||||||
cname = ('%s %s' % (prefix, name)).strip()
|
|
||||||
self._struct_pending_verification[tp] = layout, cname
|
|
||||||
|
|
||||||
def _loaded_struct_or_union(self, tp):
|
|
||||||
if tp.fldnames is None:
|
|
||||||
return # nothing to do with opaque structs
|
|
||||||
self.ffi._get_cached_btype(tp) # force 'fixedlayout' to be considered
|
|
||||||
|
|
||||||
if tp in self._struct_pending_verification:
|
|
||||||
# check that the layout sizes and offsets match the real ones
|
|
||||||
def check(realvalue, expectedvalue, msg):
|
|
||||||
if realvalue != expectedvalue:
|
|
||||||
raise VerificationError(
|
|
||||||
"%s (we have %d, but C compiler says %d)"
|
|
||||||
% (msg, expectedvalue, realvalue))
|
|
||||||
ffi = self.ffi
|
|
||||||
BStruct = ffi._get_cached_btype(tp)
|
|
||||||
layout, cname = self._struct_pending_verification.pop(tp)
|
|
||||||
check(layout[0], ffi.sizeof(BStruct), "wrong total size")
|
|
||||||
check(layout[1], ffi.alignof(BStruct), "wrong total alignment")
|
|
||||||
i = 2
|
|
||||||
for fname, ftype, fbitsize, fqual in tp.enumfields():
|
|
||||||
if fbitsize >= 0:
|
|
||||||
continue # xxx ignore fbitsize for now
|
|
||||||
check(layout[i], ffi.offsetof(BStruct, fname),
|
|
||||||
"wrong offset for field %r" % (fname,))
|
|
||||||
if layout[i+1] != 0:
|
|
||||||
BField = ffi._get_cached_btype(ftype)
|
|
||||||
check(layout[i+1], ffi.sizeof(BField),
|
|
||||||
"wrong size for field %r" % (fname,))
|
|
||||||
i += 2
|
|
||||||
assert i == len(layout)
|
|
||||||
|
|
||||||
# ----------
|
|
||||||
# 'anonymous' declarations. These are produced for anonymous structs
|
|
||||||
# or unions; the 'name' is obtained by a typedef.
|
|
||||||
|
|
||||||
def _generate_gen_anonymous_decl(self, tp, name):
|
|
||||||
if isinstance(tp, model.EnumType):
|
|
||||||
self._generate_gen_enum_decl(tp, name, '')
|
|
||||||
else:
|
|
||||||
self._generate_struct_or_union_decl(tp, '', name)
|
|
||||||
|
|
||||||
def _loading_gen_anonymous(self, tp, name, module):
|
|
||||||
if isinstance(tp, model.EnumType):
|
|
||||||
self._loading_gen_enum(tp, name, module, '')
|
|
||||||
else:
|
|
||||||
self._loading_struct_or_union(tp, '', name, module)
|
|
||||||
|
|
||||||
def _loaded_gen_anonymous(self, tp, name, module, **kwds):
|
|
||||||
if isinstance(tp, model.EnumType):
|
|
||||||
self._loaded_gen_enum(tp, name, module, **kwds)
|
|
||||||
else:
|
|
||||||
self._loaded_struct_or_union(tp)
|
|
||||||
|
|
||||||
# ----------
|
|
||||||
# constants, likely declared with '#define'
|
|
||||||
|
|
||||||
def _generate_gen_const(self, is_int, name, tp=None, category='const',
|
|
||||||
check_value=None):
|
|
||||||
prnt = self._prnt
|
|
||||||
funcname = '_cffi_%s_%s' % (category, name)
|
|
||||||
self.export_symbols.append(funcname)
|
|
||||||
if check_value is not None:
|
|
||||||
assert is_int
|
|
||||||
assert category == 'const'
|
|
||||||
prnt('int %s(char *out_error)' % funcname)
|
|
||||||
prnt('{')
|
|
||||||
self._check_int_constant_value(name, check_value)
|
|
||||||
prnt(' return 0;')
|
|
||||||
prnt('}')
|
|
||||||
elif is_int:
|
|
||||||
assert category == 'const'
|
|
||||||
prnt('int %s(long long *out_value)' % funcname)
|
|
||||||
prnt('{')
|
|
||||||
prnt(' *out_value = (long long)(%s);' % (name,))
|
|
||||||
prnt(' return (%s) <= 0;' % (name,))
|
|
||||||
prnt('}')
|
|
||||||
else:
|
|
||||||
assert tp is not None
|
|
||||||
assert check_value is None
|
|
||||||
if category == 'var':
|
|
||||||
ampersand = '&'
|
|
||||||
else:
|
|
||||||
ampersand = ''
|
|
||||||
extra = ''
|
|
||||||
if category == 'const' and isinstance(tp, model.StructOrUnion):
|
|
||||||
extra = 'const *'
|
|
||||||
ampersand = '&'
|
|
||||||
prnt(tp.get_c_name(' %s%s(void)' % (extra, funcname), name))
|
|
||||||
prnt('{')
|
|
||||||
prnt(' return (%s%s);' % (ampersand, name))
|
|
||||||
prnt('}')
|
|
||||||
prnt()
|
|
||||||
|
|
||||||
def _generate_gen_constant_decl(self, tp, name):
|
|
||||||
is_int = isinstance(tp, model.PrimitiveType) and tp.is_integer_type()
|
|
||||||
self._generate_gen_const(is_int, name, tp)
|
|
||||||
|
|
||||||
_loading_gen_constant = _loaded_noop
|
|
||||||
|
|
||||||
def _load_constant(self, is_int, tp, name, module, check_value=None):
|
|
||||||
funcname = '_cffi_const_%s' % name
|
|
||||||
if check_value is not None:
|
|
||||||
assert is_int
|
|
||||||
self._load_known_int_constant(module, funcname)
|
|
||||||
value = check_value
|
|
||||||
elif is_int:
|
|
||||||
BType = self.ffi._typeof_locked("long long*")[0]
|
|
||||||
BFunc = self.ffi._typeof_locked("int(*)(long long*)")[0]
|
|
||||||
function = module.load_function(BFunc, funcname)
|
|
||||||
p = self.ffi.new(BType)
|
|
||||||
negative = function(p)
|
|
||||||
value = int(p[0])
|
|
||||||
if value < 0 and not negative:
|
|
||||||
BLongLong = self.ffi._typeof_locked("long long")[0]
|
|
||||||
value += (1 << (8*self.ffi.sizeof(BLongLong)))
|
|
||||||
else:
|
|
||||||
assert check_value is None
|
|
||||||
fntypeextra = '(*)(void)'
|
|
||||||
if isinstance(tp, model.StructOrUnion):
|
|
||||||
fntypeextra = '*' + fntypeextra
|
|
||||||
BFunc = self.ffi._typeof_locked(tp.get_c_name(fntypeextra, name))[0]
|
|
||||||
function = module.load_function(BFunc, funcname)
|
|
||||||
value = function()
|
|
||||||
if isinstance(tp, model.StructOrUnion):
|
|
||||||
value = value[0]
|
|
||||||
return value
|
|
||||||
|
|
||||||
def _loaded_gen_constant(self, tp, name, module, library):
|
|
||||||
is_int = isinstance(tp, model.PrimitiveType) and tp.is_integer_type()
|
|
||||||
value = self._load_constant(is_int, tp, name, module)
|
|
||||||
setattr(library, name, value)
|
|
||||||
type(library)._cffi_dir.append(name)
|
|
||||||
|
|
||||||
# ----------
|
|
||||||
# enums
|
|
||||||
|
|
||||||
def _check_int_constant_value(self, name, value):
|
|
||||||
prnt = self._prnt
|
|
||||||
if value <= 0:
|
|
||||||
prnt(' if ((%s) > 0 || (long)(%s) != %dL) {' % (
|
|
||||||
name, name, value))
|
|
||||||
else:
|
|
||||||
prnt(' if ((%s) <= 0 || (unsigned long)(%s) != %dUL) {' % (
|
|
||||||
name, name, value))
|
|
||||||
prnt(' char buf[64];')
|
|
||||||
prnt(' if ((%s) <= 0)' % name)
|
|
||||||
prnt(' sprintf(buf, "%%ld", (long)(%s));' % name)
|
|
||||||
prnt(' else')
|
|
||||||
prnt(' sprintf(buf, "%%lu", (unsigned long)(%s));' %
|
|
||||||
name)
|
|
||||||
prnt(' sprintf(out_error, "%s has the real value %s, not %s",')
|
|
||||||
prnt(' "%s", buf, "%d");' % (name[:100], value))
|
|
||||||
prnt(' return -1;')
|
|
||||||
prnt(' }')
|
|
||||||
|
|
||||||
def _load_known_int_constant(self, module, funcname):
|
|
||||||
BType = self.ffi._typeof_locked("char[]")[0]
|
|
||||||
BFunc = self.ffi._typeof_locked("int(*)(char*)")[0]
|
|
||||||
function = module.load_function(BFunc, funcname)
|
|
||||||
p = self.ffi.new(BType, 256)
|
|
||||||
if function(p) < 0:
|
|
||||||
error = self.ffi.string(p)
|
|
||||||
if sys.version_info >= (3,):
|
|
||||||
error = str(error, 'utf-8')
|
|
||||||
raise VerificationError(error)
|
|
||||||
|
|
||||||
def _enum_funcname(self, prefix, name):
|
|
||||||
# "$enum_$1" => "___D_enum____D_1"
|
|
||||||
name = name.replace('$', '___D_')
|
|
||||||
return '_cffi_e_%s_%s' % (prefix, name)
|
|
||||||
|
|
||||||
def _generate_gen_enum_decl(self, tp, name, prefix='enum'):
|
|
||||||
if tp.partial:
|
|
||||||
for enumerator in tp.enumerators:
|
|
||||||
self._generate_gen_const(True, enumerator)
|
|
||||||
return
|
|
||||||
#
|
|
||||||
funcname = self._enum_funcname(prefix, name)
|
|
||||||
self.export_symbols.append(funcname)
|
|
||||||
prnt = self._prnt
|
|
||||||
prnt('int %s(char *out_error)' % funcname)
|
|
||||||
prnt('{')
|
|
||||||
for enumerator, enumvalue in zip(tp.enumerators, tp.enumvalues):
|
|
||||||
self._check_int_constant_value(enumerator, enumvalue)
|
|
||||||
prnt(' return 0;')
|
|
||||||
prnt('}')
|
|
||||||
prnt()
|
|
||||||
|
|
||||||
def _loading_gen_enum(self, tp, name, module, prefix='enum'):
|
|
||||||
if tp.partial:
|
|
||||||
enumvalues = [self._load_constant(True, tp, enumerator, module)
|
|
||||||
for enumerator in tp.enumerators]
|
|
||||||
tp.enumvalues = tuple(enumvalues)
|
|
||||||
tp.partial_resolved = True
|
|
||||||
else:
|
|
||||||
funcname = self._enum_funcname(prefix, name)
|
|
||||||
self._load_known_int_constant(module, funcname)
|
|
||||||
|
|
||||||
def _loaded_gen_enum(self, tp, name, module, library):
|
|
||||||
for enumerator, enumvalue in zip(tp.enumerators, tp.enumvalues):
|
|
||||||
setattr(library, enumerator, enumvalue)
|
|
||||||
type(library)._cffi_dir.append(enumerator)
|
|
||||||
|
|
||||||
# ----------
|
|
||||||
# macros: for now only for integers
|
|
||||||
|
|
||||||
def _generate_gen_macro_decl(self, tp, name):
|
|
||||||
if tp == '...':
|
|
||||||
check_value = None
|
|
||||||
else:
|
|
||||||
check_value = tp # an integer
|
|
||||||
self._generate_gen_const(True, name, check_value=check_value)
|
|
||||||
|
|
||||||
_loading_gen_macro = _loaded_noop
|
|
||||||
|
|
||||||
def _loaded_gen_macro(self, tp, name, module, library):
|
|
||||||
if tp == '...':
|
|
||||||
check_value = None
|
|
||||||
else:
|
|
||||||
check_value = tp # an integer
|
|
||||||
value = self._load_constant(True, tp, name, module,
|
|
||||||
check_value=check_value)
|
|
||||||
setattr(library, name, value)
|
|
||||||
type(library)._cffi_dir.append(name)
|
|
||||||
|
|
||||||
# ----------
|
|
||||||
# global variables
|
|
||||||
|
|
||||||
def _generate_gen_variable_decl(self, tp, name):
|
|
||||||
if isinstance(tp, model.ArrayType):
|
|
||||||
if tp.length_is_unknown():
|
|
||||||
prnt = self._prnt
|
|
||||||
funcname = '_cffi_sizeof_%s' % (name,)
|
|
||||||
self.export_symbols.append(funcname)
|
|
||||||
prnt("size_t %s(void)" % funcname)
|
|
||||||
prnt("{")
|
|
||||||
prnt(" return sizeof(%s);" % (name,))
|
|
||||||
prnt("}")
|
|
||||||
tp_ptr = model.PointerType(tp.item)
|
|
||||||
self._generate_gen_const(False, name, tp_ptr)
|
|
||||||
else:
|
|
||||||
tp_ptr = model.PointerType(tp)
|
|
||||||
self._generate_gen_const(False, name, tp_ptr, category='var')
|
|
||||||
|
|
||||||
_loading_gen_variable = _loaded_noop
|
|
||||||
|
|
||||||
def _loaded_gen_variable(self, tp, name, module, library):
|
|
||||||
if isinstance(tp, model.ArrayType): # int a[5] is "constant" in the
|
|
||||||
# sense that "a=..." is forbidden
|
|
||||||
if tp.length_is_unknown():
|
|
||||||
funcname = '_cffi_sizeof_%s' % (name,)
|
|
||||||
BFunc = self.ffi._typeof_locked('size_t(*)(void)')[0]
|
|
||||||
function = module.load_function(BFunc, funcname)
|
|
||||||
size = function()
|
|
||||||
BItemType = self.ffi._get_cached_btype(tp.item)
|
|
||||||
length, rest = divmod(size, self.ffi.sizeof(BItemType))
|
|
||||||
if rest != 0:
|
|
||||||
raise VerificationError(
|
|
||||||
"bad size: %r does not seem to be an array of %s" %
|
|
||||||
(name, tp.item))
|
|
||||||
tp = tp.resolve_length(length)
|
|
||||||
tp_ptr = model.PointerType(tp.item)
|
|
||||||
value = self._load_constant(False, tp_ptr, name, module)
|
|
||||||
# 'value' is a <cdata 'type *'> which we have to replace with
|
|
||||||
# a <cdata 'type[N]'> if the N is actually known
|
|
||||||
if tp.length is not None:
|
|
||||||
BArray = self.ffi._get_cached_btype(tp)
|
|
||||||
value = self.ffi.cast(BArray, value)
|
|
||||||
setattr(library, name, value)
|
|
||||||
type(library)._cffi_dir.append(name)
|
|
||||||
return
|
|
||||||
# remove ptr=<cdata 'int *'> from the library instance, and replace
|
|
||||||
# it by a property on the class, which reads/writes into ptr[0].
|
|
||||||
funcname = '_cffi_var_%s' % name
|
|
||||||
BFunc = self.ffi._typeof_locked(tp.get_c_name('*(*)(void)', name))[0]
|
|
||||||
function = module.load_function(BFunc, funcname)
|
|
||||||
ptr = function()
|
|
||||||
def getter(library):
|
|
||||||
return ptr[0]
|
|
||||||
def setter(library, value):
|
|
||||||
ptr[0] = value
|
|
||||||
setattr(type(library), name, property(getter, setter))
|
|
||||||
type(library)._cffi_dir.append(name)
|
|
||||||
|
|
||||||
cffimod_header = r'''
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <stddef.h>
|
|
||||||
#include <stdarg.h>
|
|
||||||
#include <errno.h>
|
|
||||||
#include <sys/types.h> /* XXX for ssize_t on some platforms */
|
|
||||||
|
|
||||||
/* this block of #ifs should be kept exactly identical between
|
|
||||||
c/_cffi_backend.c, cffi/vengine_cpy.py, cffi/vengine_gen.py
|
|
||||||
and cffi/_cffi_include.h */
|
|
||||||
#if defined(_MSC_VER)
|
|
||||||
# include <malloc.h> /* for alloca() */
|
|
||||||
# if _MSC_VER < 1600 /* MSVC < 2010 */
|
|
||||||
typedef __int8 int8_t;
|
|
||||||
typedef __int16 int16_t;
|
|
||||||
typedef __int32 int32_t;
|
|
||||||
typedef __int64 int64_t;
|
|
||||||
typedef unsigned __int8 uint8_t;
|
|
||||||
typedef unsigned __int16 uint16_t;
|
|
||||||
typedef unsigned __int32 uint32_t;
|
|
||||||
typedef unsigned __int64 uint64_t;
|
|
||||||
typedef __int8 int_least8_t;
|
|
||||||
typedef __int16 int_least16_t;
|
|
||||||
typedef __int32 int_least32_t;
|
|
||||||
typedef __int64 int_least64_t;
|
|
||||||
typedef unsigned __int8 uint_least8_t;
|
|
||||||
typedef unsigned __int16 uint_least16_t;
|
|
||||||
typedef unsigned __int32 uint_least32_t;
|
|
||||||
typedef unsigned __int64 uint_least64_t;
|
|
||||||
typedef __int8 int_fast8_t;
|
|
||||||
typedef __int16 int_fast16_t;
|
|
||||||
typedef __int32 int_fast32_t;
|
|
||||||
typedef __int64 int_fast64_t;
|
|
||||||
typedef unsigned __int8 uint_fast8_t;
|
|
||||||
typedef unsigned __int16 uint_fast16_t;
|
|
||||||
typedef unsigned __int32 uint_fast32_t;
|
|
||||||
typedef unsigned __int64 uint_fast64_t;
|
|
||||||
typedef __int64 intmax_t;
|
|
||||||
typedef unsigned __int64 uintmax_t;
|
|
||||||
# else
|
|
||||||
# include <stdint.h>
|
|
||||||
# endif
|
|
||||||
# if _MSC_VER < 1800 /* MSVC < 2013 */
|
|
||||||
# ifndef __cplusplus
|
|
||||||
typedef unsigned char _Bool;
|
|
||||||
# endif
|
|
||||||
# endif
|
|
||||||
#else
|
|
||||||
# include <stdint.h>
|
|
||||||
# if (defined (__SVR4) && defined (__sun)) || defined(_AIX) || defined(__hpux)
|
|
||||||
# include <alloca.h>
|
|
||||||
# endif
|
|
||||||
#endif
|
|
||||||
'''
|
|
||||||
@@ -1,307 +0,0 @@
|
|||||||
#
|
|
||||||
# DEPRECATED: implementation for ffi.verify()
|
|
||||||
#
|
|
||||||
import sys, os, binascii, shutil, io
|
|
||||||
from . import __version_verifier_modules__
|
|
||||||
from . import ffiplatform
|
|
||||||
from .error import VerificationError
|
|
||||||
|
|
||||||
if sys.version_info >= (3, 3):
|
|
||||||
import importlib.machinery
|
|
||||||
def _extension_suffixes():
|
|
||||||
return importlib.machinery.EXTENSION_SUFFIXES[:]
|
|
||||||
else:
|
|
||||||
import imp
|
|
||||||
def _extension_suffixes():
|
|
||||||
return [suffix for suffix, _, type in imp.get_suffixes()
|
|
||||||
if type == imp.C_EXTENSION]
|
|
||||||
|
|
||||||
|
|
||||||
if sys.version_info >= (3,):
|
|
||||||
NativeIO = io.StringIO
|
|
||||||
else:
|
|
||||||
class NativeIO(io.BytesIO):
|
|
||||||
def write(self, s):
|
|
||||||
if isinstance(s, unicode):
|
|
||||||
s = s.encode('ascii')
|
|
||||||
super(NativeIO, self).write(s)
|
|
||||||
|
|
||||||
|
|
||||||
class Verifier(object):
|
|
||||||
|
|
||||||
def __init__(self, ffi, preamble, tmpdir=None, modulename=None,
|
|
||||||
ext_package=None, tag='', force_generic_engine=False,
|
|
||||||
source_extension='.c', flags=None, relative_to=None, **kwds):
|
|
||||||
if ffi._parser._uses_new_feature:
|
|
||||||
raise VerificationError(
|
|
||||||
"feature not supported with ffi.verify(), but only "
|
|
||||||
"with ffi.set_source(): %s" % (ffi._parser._uses_new_feature,))
|
|
||||||
self.ffi = ffi
|
|
||||||
self.preamble = preamble
|
|
||||||
if not modulename:
|
|
||||||
flattened_kwds = ffiplatform.flatten(kwds)
|
|
||||||
vengine_class = _locate_engine_class(ffi, force_generic_engine)
|
|
||||||
self._vengine = vengine_class(self)
|
|
||||||
self._vengine.patch_extension_kwds(kwds)
|
|
||||||
self.flags = flags
|
|
||||||
self.kwds = self.make_relative_to(kwds, relative_to)
|
|
||||||
#
|
|
||||||
if modulename:
|
|
||||||
if tag:
|
|
||||||
raise TypeError("can't specify both 'modulename' and 'tag'")
|
|
||||||
else:
|
|
||||||
key = '\x00'.join(['%d.%d' % sys.version_info[:2],
|
|
||||||
__version_verifier_modules__,
|
|
||||||
preamble, flattened_kwds] +
|
|
||||||
ffi._cdefsources)
|
|
||||||
if sys.version_info >= (3,):
|
|
||||||
key = key.encode('utf-8')
|
|
||||||
k1 = hex(binascii.crc32(key[0::2]) & 0xffffffff)
|
|
||||||
k1 = k1.lstrip('0x').rstrip('L')
|
|
||||||
k2 = hex(binascii.crc32(key[1::2]) & 0xffffffff)
|
|
||||||
k2 = k2.lstrip('0').rstrip('L')
|
|
||||||
modulename = '_cffi_%s_%s%s%s' % (tag, self._vengine._class_key,
|
|
||||||
k1, k2)
|
|
||||||
suffix = _get_so_suffixes()[0]
|
|
||||||
self.tmpdir = tmpdir or _caller_dir_pycache()
|
|
||||||
self.sourcefilename = os.path.join(self.tmpdir, modulename + source_extension)
|
|
||||||
self.modulefilename = os.path.join(self.tmpdir, modulename + suffix)
|
|
||||||
self.ext_package = ext_package
|
|
||||||
self._has_source = False
|
|
||||||
self._has_module = False
|
|
||||||
|
|
||||||
def write_source(self, file=None):
|
|
||||||
"""Write the C source code. It is produced in 'self.sourcefilename',
|
|
||||||
which can be tweaked beforehand."""
|
|
||||||
with self.ffi._lock:
|
|
||||||
if self._has_source and file is None:
|
|
||||||
raise VerificationError(
|
|
||||||
"source code already written")
|
|
||||||
self._write_source(file)
|
|
||||||
|
|
||||||
def compile_module(self):
|
|
||||||
"""Write the C source code (if not done already) and compile it.
|
|
||||||
This produces a dynamic link library in 'self.modulefilename'."""
|
|
||||||
with self.ffi._lock:
|
|
||||||
if self._has_module:
|
|
||||||
raise VerificationError("module already compiled")
|
|
||||||
if not self._has_source:
|
|
||||||
self._write_source()
|
|
||||||
self._compile_module()
|
|
||||||
|
|
||||||
def load_library(self):
|
|
||||||
"""Get a C module from this Verifier instance.
|
|
||||||
Returns an instance of a FFILibrary class that behaves like the
|
|
||||||
objects returned by ffi.dlopen(), but that delegates all
|
|
||||||
operations to the C module. If necessary, the C code is written
|
|
||||||
and compiled first.
|
|
||||||
"""
|
|
||||||
with self.ffi._lock:
|
|
||||||
if not self._has_module:
|
|
||||||
self._locate_module()
|
|
||||||
if not self._has_module:
|
|
||||||
if not self._has_source:
|
|
||||||
self._write_source()
|
|
||||||
self._compile_module()
|
|
||||||
return self._load_library()
|
|
||||||
|
|
||||||
def get_module_name(self):
|
|
||||||
basename = os.path.basename(self.modulefilename)
|
|
||||||
# kill both the .so extension and the other .'s, as introduced
|
|
||||||
# by Python 3: 'basename.cpython-33m.so'
|
|
||||||
basename = basename.split('.', 1)[0]
|
|
||||||
# and the _d added in Python 2 debug builds --- but try to be
|
|
||||||
# conservative and not kill a legitimate _d
|
|
||||||
if basename.endswith('_d') and hasattr(sys, 'gettotalrefcount'):
|
|
||||||
basename = basename[:-2]
|
|
||||||
return basename
|
|
||||||
|
|
||||||
def get_extension(self):
|
|
||||||
ffiplatform._hack_at_distutils() # backward compatibility hack
|
|
||||||
if not self._has_source:
|
|
||||||
with self.ffi._lock:
|
|
||||||
if not self._has_source:
|
|
||||||
self._write_source()
|
|
||||||
sourcename = ffiplatform.maybe_relative_path(self.sourcefilename)
|
|
||||||
modname = self.get_module_name()
|
|
||||||
return ffiplatform.get_extension(sourcename, modname, **self.kwds)
|
|
||||||
|
|
||||||
def generates_python_module(self):
|
|
||||||
return self._vengine._gen_python_module
|
|
||||||
|
|
||||||
def make_relative_to(self, kwds, relative_to):
|
|
||||||
if relative_to and os.path.dirname(relative_to):
|
|
||||||
dirname = os.path.dirname(relative_to)
|
|
||||||
kwds = kwds.copy()
|
|
||||||
for key in ffiplatform.LIST_OF_FILE_NAMES:
|
|
||||||
if key in kwds:
|
|
||||||
lst = kwds[key]
|
|
||||||
if not isinstance(lst, (list, tuple)):
|
|
||||||
raise TypeError("keyword '%s' should be a list or tuple"
|
|
||||||
% (key,))
|
|
||||||
lst = [os.path.join(dirname, fn) for fn in lst]
|
|
||||||
kwds[key] = lst
|
|
||||||
return kwds
|
|
||||||
|
|
||||||
# ----------
|
|
||||||
|
|
||||||
def _locate_module(self):
|
|
||||||
if not os.path.isfile(self.modulefilename):
|
|
||||||
if self.ext_package:
|
|
||||||
try:
|
|
||||||
pkg = __import__(self.ext_package, None, None, ['__doc__'])
|
|
||||||
except ImportError:
|
|
||||||
return # cannot import the package itself, give up
|
|
||||||
# (e.g. it might be called differently before installation)
|
|
||||||
path = pkg.__path__
|
|
||||||
else:
|
|
||||||
path = None
|
|
||||||
filename = self._vengine.find_module(self.get_module_name(), path,
|
|
||||||
_get_so_suffixes())
|
|
||||||
if filename is None:
|
|
||||||
return
|
|
||||||
self.modulefilename = filename
|
|
||||||
self._vengine.collect_types()
|
|
||||||
self._has_module = True
|
|
||||||
|
|
||||||
def _write_source_to(self, file):
|
|
||||||
self._vengine._f = file
|
|
||||||
try:
|
|
||||||
self._vengine.write_source_to_f()
|
|
||||||
finally:
|
|
||||||
del self._vengine._f
|
|
||||||
|
|
||||||
def _write_source(self, file=None):
|
|
||||||
if file is not None:
|
|
||||||
self._write_source_to(file)
|
|
||||||
else:
|
|
||||||
# Write our source file to an in memory file.
|
|
||||||
f = NativeIO()
|
|
||||||
self._write_source_to(f)
|
|
||||||
source_data = f.getvalue()
|
|
||||||
|
|
||||||
# Determine if this matches the current file
|
|
||||||
if os.path.exists(self.sourcefilename):
|
|
||||||
with open(self.sourcefilename, "r") as fp:
|
|
||||||
needs_written = not (fp.read() == source_data)
|
|
||||||
else:
|
|
||||||
needs_written = True
|
|
||||||
|
|
||||||
# Actually write the file out if it doesn't match
|
|
||||||
if needs_written:
|
|
||||||
_ensure_dir(self.sourcefilename)
|
|
||||||
with open(self.sourcefilename, "w") as fp:
|
|
||||||
fp.write(source_data)
|
|
||||||
|
|
||||||
# Set this flag
|
|
||||||
self._has_source = True
|
|
||||||
|
|
||||||
def _compile_module(self):
|
|
||||||
# compile this C source
|
|
||||||
tmpdir = os.path.dirname(self.sourcefilename)
|
|
||||||
outputfilename = ffiplatform.compile(tmpdir, self.get_extension())
|
|
||||||
try:
|
|
||||||
same = ffiplatform.samefile(outputfilename, self.modulefilename)
|
|
||||||
except OSError:
|
|
||||||
same = False
|
|
||||||
if not same:
|
|
||||||
_ensure_dir(self.modulefilename)
|
|
||||||
shutil.move(outputfilename, self.modulefilename)
|
|
||||||
self._has_module = True
|
|
||||||
|
|
||||||
def _load_library(self):
|
|
||||||
assert self._has_module
|
|
||||||
if self.flags is not None:
|
|
||||||
return self._vengine.load_library(self.flags)
|
|
||||||
else:
|
|
||||||
return self._vengine.load_library()
|
|
||||||
|
|
||||||
# ____________________________________________________________
|
|
||||||
|
|
||||||
_FORCE_GENERIC_ENGINE = False # for tests
|
|
||||||
|
|
||||||
def _locate_engine_class(ffi, force_generic_engine):
|
|
||||||
if _FORCE_GENERIC_ENGINE:
|
|
||||||
force_generic_engine = True
|
|
||||||
if not force_generic_engine:
|
|
||||||
if '__pypy__' in sys.builtin_module_names:
|
|
||||||
force_generic_engine = True
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
import _cffi_backend
|
|
||||||
except ImportError:
|
|
||||||
_cffi_backend = '?'
|
|
||||||
if ffi._backend is not _cffi_backend:
|
|
||||||
force_generic_engine = True
|
|
||||||
if force_generic_engine:
|
|
||||||
from . import vengine_gen
|
|
||||||
return vengine_gen.VGenericEngine
|
|
||||||
else:
|
|
||||||
from . import vengine_cpy
|
|
||||||
return vengine_cpy.VCPythonEngine
|
|
||||||
|
|
||||||
# ____________________________________________________________
|
|
||||||
|
|
||||||
_TMPDIR = None
|
|
||||||
|
|
||||||
def _caller_dir_pycache():
|
|
||||||
if _TMPDIR:
|
|
||||||
return _TMPDIR
|
|
||||||
result = os.environ.get('CFFI_TMPDIR')
|
|
||||||
if result:
|
|
||||||
return result
|
|
||||||
filename = sys._getframe(2).f_code.co_filename
|
|
||||||
return os.path.abspath(os.path.join(os.path.dirname(filename),
|
|
||||||
'__pycache__'))
|
|
||||||
|
|
||||||
def set_tmpdir(dirname):
|
|
||||||
"""Set the temporary directory to use instead of __pycache__."""
|
|
||||||
global _TMPDIR
|
|
||||||
_TMPDIR = dirname
|
|
||||||
|
|
||||||
def cleanup_tmpdir(tmpdir=None, keep_so=False):
|
|
||||||
"""Clean up the temporary directory by removing all files in it
|
|
||||||
called `_cffi_*.{c,so}` as well as the `build` subdirectory."""
|
|
||||||
tmpdir = tmpdir or _caller_dir_pycache()
|
|
||||||
try:
|
|
||||||
filelist = os.listdir(tmpdir)
|
|
||||||
except OSError:
|
|
||||||
return
|
|
||||||
if keep_so:
|
|
||||||
suffix = '.c' # only remove .c files
|
|
||||||
else:
|
|
||||||
suffix = _get_so_suffixes()[0].lower()
|
|
||||||
for fn in filelist:
|
|
||||||
if fn.lower().startswith('_cffi_') and (
|
|
||||||
fn.lower().endswith(suffix) or fn.lower().endswith('.c')):
|
|
||||||
try:
|
|
||||||
os.unlink(os.path.join(tmpdir, fn))
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
clean_dir = [os.path.join(tmpdir, 'build')]
|
|
||||||
for dir in clean_dir:
|
|
||||||
try:
|
|
||||||
for fn in os.listdir(dir):
|
|
||||||
fn = os.path.join(dir, fn)
|
|
||||||
if os.path.isdir(fn):
|
|
||||||
clean_dir.append(fn)
|
|
||||||
else:
|
|
||||||
os.unlink(fn)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _get_so_suffixes():
|
|
||||||
suffixes = _extension_suffixes()
|
|
||||||
if not suffixes:
|
|
||||||
# bah, no C_EXTENSION available. Occurs on pypy without cpyext
|
|
||||||
if sys.platform == 'win32':
|
|
||||||
suffixes = [".pyd"]
|
|
||||||
else:
|
|
||||||
suffixes = [".so"]
|
|
||||||
|
|
||||||
return suffixes
|
|
||||||
|
|
||||||
def _ensure_dir(filename):
|
|
||||||
dirname = os.path.dirname(filename)
|
|
||||||
if dirname and not os.path.isdir(dirname):
|
|
||||||
os.makedirs(dirname)
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"__version__",
|
|
||||||
"__author__",
|
|
||||||
"__copyright__",
|
|
||||||
]
|
|
||||||
|
|
||||||
__version__ = "37.0.2"
|
|
||||||
|
|
||||||
__author__ = "The Python Cryptographic Authority and individual contributors"
|
|
||||||
__copyright__ = "Copyright 2013-2021 {}".format(__author__)
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
from cryptography.__about__ import (
|
|
||||||
__author__,
|
|
||||||
__copyright__,
|
|
||||||
__version__,
|
|
||||||
)
|
|
||||||
from cryptography.utils import CryptographyDeprecationWarning
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"__version__",
|
|
||||||
"__author__",
|
|
||||||
"__copyright__",
|
|
||||||
]
|
|
||||||
|
|
||||||
if sys.version_info[:2] == (3, 6):
|
|
||||||
warnings.warn(
|
|
||||||
"Python 3.6 is no longer supported by the Python core team. "
|
|
||||||
"Therefore, support for it is deprecated in cryptography and will be"
|
|
||||||
" removed in a future release.",
|
|
||||||
CryptographyDeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography import utils
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from cryptography.hazmat.bindings.openssl.binding import (
|
|
||||||
_OpenSSLErrorWithText,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _Reasons(utils.Enum):
|
|
||||||
BACKEND_MISSING_INTERFACE = 0
|
|
||||||
UNSUPPORTED_HASH = 1
|
|
||||||
UNSUPPORTED_CIPHER = 2
|
|
||||||
UNSUPPORTED_PADDING = 3
|
|
||||||
UNSUPPORTED_MGF = 4
|
|
||||||
UNSUPPORTED_PUBLIC_KEY_ALGORITHM = 5
|
|
||||||
UNSUPPORTED_ELLIPTIC_CURVE = 6
|
|
||||||
UNSUPPORTED_SERIALIZATION = 7
|
|
||||||
UNSUPPORTED_X509 = 8
|
|
||||||
UNSUPPORTED_EXCHANGE_ALGORITHM = 9
|
|
||||||
UNSUPPORTED_DIFFIE_HELLMAN = 10
|
|
||||||
UNSUPPORTED_MAC = 11
|
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedAlgorithm(Exception):
|
|
||||||
def __init__(
|
|
||||||
self, message: str, reason: typing.Optional[_Reasons] = None
|
|
||||||
) -> None:
|
|
||||||
super(UnsupportedAlgorithm, self).__init__(message)
|
|
||||||
self._reason = reason
|
|
||||||
|
|
||||||
|
|
||||||
class AlreadyFinalized(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class AlreadyUpdated(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class NotYetFinalized(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidTag(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidSignature(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InternalError(Exception):
|
|
||||||
def __init__(
|
|
||||||
self, msg: str, err_code: typing.List["_OpenSSLErrorWithText"]
|
|
||||||
) -> None:
|
|
||||||
super(InternalError, self).__init__(msg)
|
|
||||||
self.err_code = err_code
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidKey(Exception):
|
|
||||||
pass
|
|
||||||
@@ -1,212 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import binascii
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography import utils
|
|
||||||
from cryptography.exceptions import InvalidSignature
|
|
||||||
from cryptography.hazmat.primitives import hashes, padding
|
|
||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
|
||||||
from cryptography.hazmat.primitives.hmac import HMAC
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidToken(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
_MAX_CLOCK_SKEW = 60
|
|
||||||
|
|
||||||
|
|
||||||
class Fernet:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
key: typing.Union[bytes, str],
|
|
||||||
backend: typing.Any = None,
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
key = base64.urlsafe_b64decode(key)
|
|
||||||
except binascii.Error as exc:
|
|
||||||
raise ValueError(
|
|
||||||
"Fernet key must be 32 url-safe base64-encoded bytes."
|
|
||||||
) from exc
|
|
||||||
if len(key) != 32:
|
|
||||||
raise ValueError(
|
|
||||||
"Fernet key must be 32 url-safe base64-encoded bytes."
|
|
||||||
)
|
|
||||||
|
|
||||||
self._signing_key = key[:16]
|
|
||||||
self._encryption_key = key[16:]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def generate_key(cls) -> bytes:
|
|
||||||
return base64.urlsafe_b64encode(os.urandom(32))
|
|
||||||
|
|
||||||
def encrypt(self, data: bytes) -> bytes:
|
|
||||||
return self.encrypt_at_time(data, int(time.time()))
|
|
||||||
|
|
||||||
def encrypt_at_time(self, data: bytes, current_time: int) -> bytes:
|
|
||||||
iv = os.urandom(16)
|
|
||||||
return self._encrypt_from_parts(data, current_time, iv)
|
|
||||||
|
|
||||||
def _encrypt_from_parts(
|
|
||||||
self, data: bytes, current_time: int, iv: bytes
|
|
||||||
) -> bytes:
|
|
||||||
utils._check_bytes("data", data)
|
|
||||||
|
|
||||||
padder = padding.PKCS7(algorithms.AES.block_size).padder()
|
|
||||||
padded_data = padder.update(data) + padder.finalize()
|
|
||||||
encryptor = Cipher(
|
|
||||||
algorithms.AES(self._encryption_key),
|
|
||||||
modes.CBC(iv),
|
|
||||||
).encryptor()
|
|
||||||
ciphertext = encryptor.update(padded_data) + encryptor.finalize()
|
|
||||||
|
|
||||||
basic_parts = (
|
|
||||||
b"\x80"
|
|
||||||
+ current_time.to_bytes(length=8, byteorder="big")
|
|
||||||
+ iv
|
|
||||||
+ ciphertext
|
|
||||||
)
|
|
||||||
|
|
||||||
h = HMAC(self._signing_key, hashes.SHA256())
|
|
||||||
h.update(basic_parts)
|
|
||||||
hmac = h.finalize()
|
|
||||||
return base64.urlsafe_b64encode(basic_parts + hmac)
|
|
||||||
|
|
||||||
def decrypt(self, token: bytes, ttl: typing.Optional[int] = None) -> bytes:
|
|
||||||
timestamp, data = Fernet._get_unverified_token_data(token)
|
|
||||||
if ttl is None:
|
|
||||||
time_info = None
|
|
||||||
else:
|
|
||||||
time_info = (ttl, int(time.time()))
|
|
||||||
return self._decrypt_data(data, timestamp, time_info)
|
|
||||||
|
|
||||||
def decrypt_at_time(
|
|
||||||
self, token: bytes, ttl: int, current_time: int
|
|
||||||
) -> bytes:
|
|
||||||
if ttl is None:
|
|
||||||
raise ValueError(
|
|
||||||
"decrypt_at_time() can only be used with a non-None ttl"
|
|
||||||
)
|
|
||||||
timestamp, data = Fernet._get_unverified_token_data(token)
|
|
||||||
return self._decrypt_data(data, timestamp, (ttl, current_time))
|
|
||||||
|
|
||||||
def extract_timestamp(self, token: bytes) -> int:
|
|
||||||
timestamp, data = Fernet._get_unverified_token_data(token)
|
|
||||||
# Verify the token was not tampered with.
|
|
||||||
self._verify_signature(data)
|
|
||||||
return timestamp
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_unverified_token_data(token: bytes) -> typing.Tuple[int, bytes]:
|
|
||||||
utils._check_bytes("token", token)
|
|
||||||
try:
|
|
||||||
data = base64.urlsafe_b64decode(token)
|
|
||||||
except (TypeError, binascii.Error):
|
|
||||||
raise InvalidToken
|
|
||||||
|
|
||||||
if not data or data[0] != 0x80:
|
|
||||||
raise InvalidToken
|
|
||||||
|
|
||||||
if len(data) < 9:
|
|
||||||
raise InvalidToken
|
|
||||||
|
|
||||||
timestamp = int.from_bytes(data[1:9], byteorder="big")
|
|
||||||
return timestamp, data
|
|
||||||
|
|
||||||
def _verify_signature(self, data: bytes) -> None:
|
|
||||||
h = HMAC(self._signing_key, hashes.SHA256())
|
|
||||||
h.update(data[:-32])
|
|
||||||
try:
|
|
||||||
h.verify(data[-32:])
|
|
||||||
except InvalidSignature:
|
|
||||||
raise InvalidToken
|
|
||||||
|
|
||||||
def _decrypt_data(
|
|
||||||
self,
|
|
||||||
data: bytes,
|
|
||||||
timestamp: int,
|
|
||||||
time_info: typing.Optional[typing.Tuple[int, int]],
|
|
||||||
) -> bytes:
|
|
||||||
if time_info is not None:
|
|
||||||
ttl, current_time = time_info
|
|
||||||
if timestamp + ttl < current_time:
|
|
||||||
raise InvalidToken
|
|
||||||
|
|
||||||
if current_time + _MAX_CLOCK_SKEW < timestamp:
|
|
||||||
raise InvalidToken
|
|
||||||
|
|
||||||
self._verify_signature(data)
|
|
||||||
|
|
||||||
iv = data[9:25]
|
|
||||||
ciphertext = data[25:-32]
|
|
||||||
decryptor = Cipher(
|
|
||||||
algorithms.AES(self._encryption_key), modes.CBC(iv)
|
|
||||||
).decryptor()
|
|
||||||
plaintext_padded = decryptor.update(ciphertext)
|
|
||||||
try:
|
|
||||||
plaintext_padded += decryptor.finalize()
|
|
||||||
except ValueError:
|
|
||||||
raise InvalidToken
|
|
||||||
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
|
|
||||||
|
|
||||||
unpadded = unpadder.update(plaintext_padded)
|
|
||||||
try:
|
|
||||||
unpadded += unpadder.finalize()
|
|
||||||
except ValueError:
|
|
||||||
raise InvalidToken
|
|
||||||
return unpadded
|
|
||||||
|
|
||||||
|
|
||||||
class MultiFernet:
|
|
||||||
def __init__(self, fernets: typing.Iterable[Fernet]):
|
|
||||||
fernets = list(fernets)
|
|
||||||
if not fernets:
|
|
||||||
raise ValueError(
|
|
||||||
"MultiFernet requires at least one Fernet instance"
|
|
||||||
)
|
|
||||||
self._fernets = fernets
|
|
||||||
|
|
||||||
def encrypt(self, msg: bytes) -> bytes:
|
|
||||||
return self.encrypt_at_time(msg, int(time.time()))
|
|
||||||
|
|
||||||
def encrypt_at_time(self, msg: bytes, current_time: int) -> bytes:
|
|
||||||
return self._fernets[0].encrypt_at_time(msg, current_time)
|
|
||||||
|
|
||||||
def rotate(self, msg: bytes) -> bytes:
|
|
||||||
timestamp, data = Fernet._get_unverified_token_data(msg)
|
|
||||||
for f in self._fernets:
|
|
||||||
try:
|
|
||||||
p = f._decrypt_data(data, timestamp, None)
|
|
||||||
break
|
|
||||||
except InvalidToken:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise InvalidToken
|
|
||||||
|
|
||||||
iv = os.urandom(16)
|
|
||||||
return self._fernets[0]._encrypt_from_parts(p, timestamp, iv)
|
|
||||||
|
|
||||||
def decrypt(self, msg: bytes, ttl: typing.Optional[int] = None) -> bytes:
|
|
||||||
for f in self._fernets:
|
|
||||||
try:
|
|
||||||
return f.decrypt(msg, ttl)
|
|
||||||
except InvalidToken:
|
|
||||||
pass
|
|
||||||
raise InvalidToken
|
|
||||||
|
|
||||||
def decrypt_at_time(
|
|
||||||
self, msg: bytes, ttl: int, current_time: int
|
|
||||||
) -> bytes:
|
|
||||||
for f in self._fernets:
|
|
||||||
try:
|
|
||||||
return f.decrypt_at_time(msg, ttl, current_time)
|
|
||||||
except InvalidToken:
|
|
||||||
pass
|
|
||||||
raise InvalidToken
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
"""
|
|
||||||
Hazardous Materials
|
|
||||||
|
|
||||||
This is a "Hazardous Materials" module. You should ONLY use it if you're
|
|
||||||
100% absolutely sure that you know what you're doing because this module
|
|
||||||
is full of land mines, dragons, and dinosaurs with laser guns.
|
|
||||||
"""
|
|
||||||
@@ -1,345 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.hazmat.primitives import hashes
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectIdentifier:
|
|
||||||
def __init__(self, dotted_string: str) -> None:
|
|
||||||
self._dotted_string = dotted_string
|
|
||||||
|
|
||||||
nodes = self._dotted_string.split(".")
|
|
||||||
intnodes = []
|
|
||||||
|
|
||||||
# There must be at least 2 nodes, the first node must be 0..2, and
|
|
||||||
# if less than 2, the second node cannot have a value outside the
|
|
||||||
# range 0..39. All nodes must be integers.
|
|
||||||
for node in nodes:
|
|
||||||
try:
|
|
||||||
node_value = int(node, 10)
|
|
||||||
except ValueError:
|
|
||||||
raise ValueError(
|
|
||||||
f"Malformed OID: {dotted_string} (non-integer nodes)"
|
|
||||||
)
|
|
||||||
if node_value < 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"Malformed OID: {dotted_string} (negative-integer nodes)"
|
|
||||||
)
|
|
||||||
intnodes.append(node_value)
|
|
||||||
|
|
||||||
if len(nodes) < 2:
|
|
||||||
raise ValueError(
|
|
||||||
f"Malformed OID: {dotted_string} "
|
|
||||||
"(insufficient number of nodes)"
|
|
||||||
)
|
|
||||||
|
|
||||||
if intnodes[0] > 2:
|
|
||||||
raise ValueError(
|
|
||||||
f"Malformed OID: {dotted_string} "
|
|
||||||
"(first node outside valid range)"
|
|
||||||
)
|
|
||||||
|
|
||||||
if intnodes[0] < 2 and intnodes[1] >= 40:
|
|
||||||
raise ValueError(
|
|
||||||
f"Malformed OID: {dotted_string} "
|
|
||||||
"(second node outside valid range)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, ObjectIdentifier):
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
return self.dotted_string == other.dotted_string
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return "<ObjectIdentifier(oid={}, name={})>".format(
|
|
||||||
self.dotted_string, self._name
|
|
||||||
)
|
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
|
||||||
return hash(self.dotted_string)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _name(self) -> str:
|
|
||||||
return _OID_NAMES.get(self, "Unknown OID")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dotted_string(self) -> str:
|
|
||||||
return self._dotted_string
|
|
||||||
|
|
||||||
|
|
||||||
class ExtensionOID:
|
|
||||||
SUBJECT_DIRECTORY_ATTRIBUTES = ObjectIdentifier("2.5.29.9")
|
|
||||||
SUBJECT_KEY_IDENTIFIER = ObjectIdentifier("2.5.29.14")
|
|
||||||
KEY_USAGE = ObjectIdentifier("2.5.29.15")
|
|
||||||
SUBJECT_ALTERNATIVE_NAME = ObjectIdentifier("2.5.29.17")
|
|
||||||
ISSUER_ALTERNATIVE_NAME = ObjectIdentifier("2.5.29.18")
|
|
||||||
BASIC_CONSTRAINTS = ObjectIdentifier("2.5.29.19")
|
|
||||||
NAME_CONSTRAINTS = ObjectIdentifier("2.5.29.30")
|
|
||||||
CRL_DISTRIBUTION_POINTS = ObjectIdentifier("2.5.29.31")
|
|
||||||
CERTIFICATE_POLICIES = ObjectIdentifier("2.5.29.32")
|
|
||||||
POLICY_MAPPINGS = ObjectIdentifier("2.5.29.33")
|
|
||||||
AUTHORITY_KEY_IDENTIFIER = ObjectIdentifier("2.5.29.35")
|
|
||||||
POLICY_CONSTRAINTS = ObjectIdentifier("2.5.29.36")
|
|
||||||
EXTENDED_KEY_USAGE = ObjectIdentifier("2.5.29.37")
|
|
||||||
FRESHEST_CRL = ObjectIdentifier("2.5.29.46")
|
|
||||||
INHIBIT_ANY_POLICY = ObjectIdentifier("2.5.29.54")
|
|
||||||
ISSUING_DISTRIBUTION_POINT = ObjectIdentifier("2.5.29.28")
|
|
||||||
AUTHORITY_INFORMATION_ACCESS = ObjectIdentifier("1.3.6.1.5.5.7.1.1")
|
|
||||||
SUBJECT_INFORMATION_ACCESS = ObjectIdentifier("1.3.6.1.5.5.7.1.11")
|
|
||||||
OCSP_NO_CHECK = ObjectIdentifier("1.3.6.1.5.5.7.48.1.5")
|
|
||||||
TLS_FEATURE = ObjectIdentifier("1.3.6.1.5.5.7.1.24")
|
|
||||||
CRL_NUMBER = ObjectIdentifier("2.5.29.20")
|
|
||||||
DELTA_CRL_INDICATOR = ObjectIdentifier("2.5.29.27")
|
|
||||||
PRECERT_SIGNED_CERTIFICATE_TIMESTAMPS = ObjectIdentifier(
|
|
||||||
"1.3.6.1.4.1.11129.2.4.2"
|
|
||||||
)
|
|
||||||
PRECERT_POISON = ObjectIdentifier("1.3.6.1.4.1.11129.2.4.3")
|
|
||||||
SIGNED_CERTIFICATE_TIMESTAMPS = ObjectIdentifier("1.3.6.1.4.1.11129.2.4.5")
|
|
||||||
|
|
||||||
|
|
||||||
class OCSPExtensionOID:
|
|
||||||
NONCE = ObjectIdentifier("1.3.6.1.5.5.7.48.1.2")
|
|
||||||
|
|
||||||
|
|
||||||
class CRLEntryExtensionOID:
|
|
||||||
CERTIFICATE_ISSUER = ObjectIdentifier("2.5.29.29")
|
|
||||||
CRL_REASON = ObjectIdentifier("2.5.29.21")
|
|
||||||
INVALIDITY_DATE = ObjectIdentifier("2.5.29.24")
|
|
||||||
|
|
||||||
|
|
||||||
class NameOID:
|
|
||||||
COMMON_NAME = ObjectIdentifier("2.5.4.3")
|
|
||||||
COUNTRY_NAME = ObjectIdentifier("2.5.4.6")
|
|
||||||
LOCALITY_NAME = ObjectIdentifier("2.5.4.7")
|
|
||||||
STATE_OR_PROVINCE_NAME = ObjectIdentifier("2.5.4.8")
|
|
||||||
STREET_ADDRESS = ObjectIdentifier("2.5.4.9")
|
|
||||||
ORGANIZATION_NAME = ObjectIdentifier("2.5.4.10")
|
|
||||||
ORGANIZATIONAL_UNIT_NAME = ObjectIdentifier("2.5.4.11")
|
|
||||||
SERIAL_NUMBER = ObjectIdentifier("2.5.4.5")
|
|
||||||
SURNAME = ObjectIdentifier("2.5.4.4")
|
|
||||||
GIVEN_NAME = ObjectIdentifier("2.5.4.42")
|
|
||||||
TITLE = ObjectIdentifier("2.5.4.12")
|
|
||||||
GENERATION_QUALIFIER = ObjectIdentifier("2.5.4.44")
|
|
||||||
X500_UNIQUE_IDENTIFIER = ObjectIdentifier("2.5.4.45")
|
|
||||||
DN_QUALIFIER = ObjectIdentifier("2.5.4.46")
|
|
||||||
PSEUDONYM = ObjectIdentifier("2.5.4.65")
|
|
||||||
USER_ID = ObjectIdentifier("0.9.2342.19200300.100.1.1")
|
|
||||||
DOMAIN_COMPONENT = ObjectIdentifier("0.9.2342.19200300.100.1.25")
|
|
||||||
EMAIL_ADDRESS = ObjectIdentifier("1.2.840.113549.1.9.1")
|
|
||||||
JURISDICTION_COUNTRY_NAME = ObjectIdentifier("1.3.6.1.4.1.311.60.2.1.3")
|
|
||||||
JURISDICTION_LOCALITY_NAME = ObjectIdentifier("1.3.6.1.4.1.311.60.2.1.1")
|
|
||||||
JURISDICTION_STATE_OR_PROVINCE_NAME = ObjectIdentifier(
|
|
||||||
"1.3.6.1.4.1.311.60.2.1.2"
|
|
||||||
)
|
|
||||||
BUSINESS_CATEGORY = ObjectIdentifier("2.5.4.15")
|
|
||||||
POSTAL_ADDRESS = ObjectIdentifier("2.5.4.16")
|
|
||||||
POSTAL_CODE = ObjectIdentifier("2.5.4.17")
|
|
||||||
INN = ObjectIdentifier("1.2.643.3.131.1.1")
|
|
||||||
OGRN = ObjectIdentifier("1.2.643.100.1")
|
|
||||||
SNILS = ObjectIdentifier("1.2.643.100.3")
|
|
||||||
UNSTRUCTURED_NAME = ObjectIdentifier("1.2.840.113549.1.9.2")
|
|
||||||
|
|
||||||
|
|
||||||
class SignatureAlgorithmOID:
|
|
||||||
RSA_WITH_MD5 = ObjectIdentifier("1.2.840.113549.1.1.4")
|
|
||||||
RSA_WITH_SHA1 = ObjectIdentifier("1.2.840.113549.1.1.5")
|
|
||||||
# This is an alternate OID for RSA with SHA1 that is occasionally seen
|
|
||||||
_RSA_WITH_SHA1 = ObjectIdentifier("1.3.14.3.2.29")
|
|
||||||
RSA_WITH_SHA224 = ObjectIdentifier("1.2.840.113549.1.1.14")
|
|
||||||
RSA_WITH_SHA256 = ObjectIdentifier("1.2.840.113549.1.1.11")
|
|
||||||
RSA_WITH_SHA384 = ObjectIdentifier("1.2.840.113549.1.1.12")
|
|
||||||
RSA_WITH_SHA512 = ObjectIdentifier("1.2.840.113549.1.1.13")
|
|
||||||
RSA_WITH_SHA3_224 = ObjectIdentifier("2.16.840.1.101.3.4.3.13")
|
|
||||||
RSA_WITH_SHA3_256 = ObjectIdentifier("2.16.840.1.101.3.4.3.14")
|
|
||||||
RSA_WITH_SHA3_384 = ObjectIdentifier("2.16.840.1.101.3.4.3.15")
|
|
||||||
RSA_WITH_SHA3_512 = ObjectIdentifier("2.16.840.1.101.3.4.3.16")
|
|
||||||
RSASSA_PSS = ObjectIdentifier("1.2.840.113549.1.1.10")
|
|
||||||
ECDSA_WITH_SHA1 = ObjectIdentifier("1.2.840.10045.4.1")
|
|
||||||
ECDSA_WITH_SHA224 = ObjectIdentifier("1.2.840.10045.4.3.1")
|
|
||||||
ECDSA_WITH_SHA256 = ObjectIdentifier("1.2.840.10045.4.3.2")
|
|
||||||
ECDSA_WITH_SHA384 = ObjectIdentifier("1.2.840.10045.4.3.3")
|
|
||||||
ECDSA_WITH_SHA512 = ObjectIdentifier("1.2.840.10045.4.3.4")
|
|
||||||
ECDSA_WITH_SHA3_224 = ObjectIdentifier("2.16.840.1.101.3.4.3.9")
|
|
||||||
ECDSA_WITH_SHA3_256 = ObjectIdentifier("2.16.840.1.101.3.4.3.10")
|
|
||||||
ECDSA_WITH_SHA3_384 = ObjectIdentifier("2.16.840.1.101.3.4.3.11")
|
|
||||||
ECDSA_WITH_SHA3_512 = ObjectIdentifier("2.16.840.1.101.3.4.3.12")
|
|
||||||
DSA_WITH_SHA1 = ObjectIdentifier("1.2.840.10040.4.3")
|
|
||||||
DSA_WITH_SHA224 = ObjectIdentifier("2.16.840.1.101.3.4.3.1")
|
|
||||||
DSA_WITH_SHA256 = ObjectIdentifier("2.16.840.1.101.3.4.3.2")
|
|
||||||
DSA_WITH_SHA384 = ObjectIdentifier("2.16.840.1.101.3.4.3.3")
|
|
||||||
DSA_WITH_SHA512 = ObjectIdentifier("2.16.840.1.101.3.4.3.4")
|
|
||||||
ED25519 = ObjectIdentifier("1.3.101.112")
|
|
||||||
ED448 = ObjectIdentifier("1.3.101.113")
|
|
||||||
GOSTR3411_94_WITH_3410_2001 = ObjectIdentifier("1.2.643.2.2.3")
|
|
||||||
GOSTR3410_2012_WITH_3411_2012_256 = ObjectIdentifier("1.2.643.7.1.1.3.2")
|
|
||||||
GOSTR3410_2012_WITH_3411_2012_512 = ObjectIdentifier("1.2.643.7.1.1.3.3")
|
|
||||||
|
|
||||||
|
|
||||||
_SIG_OIDS_TO_HASH: typing.Dict[
|
|
||||||
ObjectIdentifier, typing.Optional[hashes.HashAlgorithm]
|
|
||||||
] = {
|
|
||||||
SignatureAlgorithmOID.RSA_WITH_MD5: hashes.MD5(),
|
|
||||||
SignatureAlgorithmOID.RSA_WITH_SHA1: hashes.SHA1(),
|
|
||||||
SignatureAlgorithmOID._RSA_WITH_SHA1: hashes.SHA1(),
|
|
||||||
SignatureAlgorithmOID.RSA_WITH_SHA224: hashes.SHA224(),
|
|
||||||
SignatureAlgorithmOID.RSA_WITH_SHA256: hashes.SHA256(),
|
|
||||||
SignatureAlgorithmOID.RSA_WITH_SHA384: hashes.SHA384(),
|
|
||||||
SignatureAlgorithmOID.RSA_WITH_SHA512: hashes.SHA512(),
|
|
||||||
SignatureAlgorithmOID.ECDSA_WITH_SHA1: hashes.SHA1(),
|
|
||||||
SignatureAlgorithmOID.ECDSA_WITH_SHA224: hashes.SHA224(),
|
|
||||||
SignatureAlgorithmOID.ECDSA_WITH_SHA256: hashes.SHA256(),
|
|
||||||
SignatureAlgorithmOID.ECDSA_WITH_SHA384: hashes.SHA384(),
|
|
||||||
SignatureAlgorithmOID.ECDSA_WITH_SHA512: hashes.SHA512(),
|
|
||||||
SignatureAlgorithmOID.DSA_WITH_SHA1: hashes.SHA1(),
|
|
||||||
SignatureAlgorithmOID.DSA_WITH_SHA224: hashes.SHA224(),
|
|
||||||
SignatureAlgorithmOID.DSA_WITH_SHA256: hashes.SHA256(),
|
|
||||||
SignatureAlgorithmOID.ED25519: None,
|
|
||||||
SignatureAlgorithmOID.ED448: None,
|
|
||||||
SignatureAlgorithmOID.GOSTR3411_94_WITH_3410_2001: None,
|
|
||||||
SignatureAlgorithmOID.GOSTR3410_2012_WITH_3411_2012_256: None,
|
|
||||||
SignatureAlgorithmOID.GOSTR3410_2012_WITH_3411_2012_512: None,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ExtendedKeyUsageOID:
|
|
||||||
SERVER_AUTH = ObjectIdentifier("1.3.6.1.5.5.7.3.1")
|
|
||||||
CLIENT_AUTH = ObjectIdentifier("1.3.6.1.5.5.7.3.2")
|
|
||||||
CODE_SIGNING = ObjectIdentifier("1.3.6.1.5.5.7.3.3")
|
|
||||||
EMAIL_PROTECTION = ObjectIdentifier("1.3.6.1.5.5.7.3.4")
|
|
||||||
TIME_STAMPING = ObjectIdentifier("1.3.6.1.5.5.7.3.8")
|
|
||||||
OCSP_SIGNING = ObjectIdentifier("1.3.6.1.5.5.7.3.9")
|
|
||||||
ANY_EXTENDED_KEY_USAGE = ObjectIdentifier("2.5.29.37.0")
|
|
||||||
SMARTCARD_LOGON = ObjectIdentifier("1.3.6.1.4.1.311.20.2.2")
|
|
||||||
KERBEROS_PKINIT_KDC = ObjectIdentifier("1.3.6.1.5.2.3.5")
|
|
||||||
IPSEC_IKE = ObjectIdentifier("1.3.6.1.5.5.7.3.17")
|
|
||||||
|
|
||||||
|
|
||||||
class AuthorityInformationAccessOID:
|
|
||||||
CA_ISSUERS = ObjectIdentifier("1.3.6.1.5.5.7.48.2")
|
|
||||||
OCSP = ObjectIdentifier("1.3.6.1.5.5.7.48.1")
|
|
||||||
|
|
||||||
|
|
||||||
class SubjectInformationAccessOID:
|
|
||||||
CA_REPOSITORY = ObjectIdentifier("1.3.6.1.5.5.7.48.5")
|
|
||||||
|
|
||||||
|
|
||||||
class CertificatePoliciesOID:
|
|
||||||
CPS_QUALIFIER = ObjectIdentifier("1.3.6.1.5.5.7.2.1")
|
|
||||||
CPS_USER_NOTICE = ObjectIdentifier("1.3.6.1.5.5.7.2.2")
|
|
||||||
ANY_POLICY = ObjectIdentifier("2.5.29.32.0")
|
|
||||||
|
|
||||||
|
|
||||||
class AttributeOID:
|
|
||||||
CHALLENGE_PASSWORD = ObjectIdentifier("1.2.840.113549.1.9.7")
|
|
||||||
UNSTRUCTURED_NAME = ObjectIdentifier("1.2.840.113549.1.9.2")
|
|
||||||
|
|
||||||
|
|
||||||
_OID_NAMES = {
|
|
||||||
NameOID.COMMON_NAME: "commonName",
|
|
||||||
NameOID.COUNTRY_NAME: "countryName",
|
|
||||||
NameOID.LOCALITY_NAME: "localityName",
|
|
||||||
NameOID.STATE_OR_PROVINCE_NAME: "stateOrProvinceName",
|
|
||||||
NameOID.STREET_ADDRESS: "streetAddress",
|
|
||||||
NameOID.ORGANIZATION_NAME: "organizationName",
|
|
||||||
NameOID.ORGANIZATIONAL_UNIT_NAME: "organizationalUnitName",
|
|
||||||
NameOID.SERIAL_NUMBER: "serialNumber",
|
|
||||||
NameOID.SURNAME: "surname",
|
|
||||||
NameOID.GIVEN_NAME: "givenName",
|
|
||||||
NameOID.TITLE: "title",
|
|
||||||
NameOID.GENERATION_QUALIFIER: "generationQualifier",
|
|
||||||
NameOID.X500_UNIQUE_IDENTIFIER: "x500UniqueIdentifier",
|
|
||||||
NameOID.DN_QUALIFIER: "dnQualifier",
|
|
||||||
NameOID.PSEUDONYM: "pseudonym",
|
|
||||||
NameOID.USER_ID: "userID",
|
|
||||||
NameOID.DOMAIN_COMPONENT: "domainComponent",
|
|
||||||
NameOID.EMAIL_ADDRESS: "emailAddress",
|
|
||||||
NameOID.JURISDICTION_COUNTRY_NAME: "jurisdictionCountryName",
|
|
||||||
NameOID.JURISDICTION_LOCALITY_NAME: "jurisdictionLocalityName",
|
|
||||||
NameOID.JURISDICTION_STATE_OR_PROVINCE_NAME: (
|
|
||||||
"jurisdictionStateOrProvinceName"
|
|
||||||
),
|
|
||||||
NameOID.BUSINESS_CATEGORY: "businessCategory",
|
|
||||||
NameOID.POSTAL_ADDRESS: "postalAddress",
|
|
||||||
NameOID.POSTAL_CODE: "postalCode",
|
|
||||||
NameOID.INN: "INN",
|
|
||||||
NameOID.OGRN: "OGRN",
|
|
||||||
NameOID.SNILS: "SNILS",
|
|
||||||
NameOID.UNSTRUCTURED_NAME: "unstructuredName",
|
|
||||||
SignatureAlgorithmOID.RSA_WITH_MD5: "md5WithRSAEncryption",
|
|
||||||
SignatureAlgorithmOID.RSA_WITH_SHA1: "sha1WithRSAEncryption",
|
|
||||||
SignatureAlgorithmOID.RSA_WITH_SHA224: "sha224WithRSAEncryption",
|
|
||||||
SignatureAlgorithmOID.RSA_WITH_SHA256: "sha256WithRSAEncryption",
|
|
||||||
SignatureAlgorithmOID.RSA_WITH_SHA384: "sha384WithRSAEncryption",
|
|
||||||
SignatureAlgorithmOID.RSA_WITH_SHA512: "sha512WithRSAEncryption",
|
|
||||||
SignatureAlgorithmOID.RSASSA_PSS: "RSASSA-PSS",
|
|
||||||
SignatureAlgorithmOID.ECDSA_WITH_SHA1: "ecdsa-with-SHA1",
|
|
||||||
SignatureAlgorithmOID.ECDSA_WITH_SHA224: "ecdsa-with-SHA224",
|
|
||||||
SignatureAlgorithmOID.ECDSA_WITH_SHA256: "ecdsa-with-SHA256",
|
|
||||||
SignatureAlgorithmOID.ECDSA_WITH_SHA384: "ecdsa-with-SHA384",
|
|
||||||
SignatureAlgorithmOID.ECDSA_WITH_SHA512: "ecdsa-with-SHA512",
|
|
||||||
SignatureAlgorithmOID.DSA_WITH_SHA1: "dsa-with-sha1",
|
|
||||||
SignatureAlgorithmOID.DSA_WITH_SHA224: "dsa-with-sha224",
|
|
||||||
SignatureAlgorithmOID.DSA_WITH_SHA256: "dsa-with-sha256",
|
|
||||||
SignatureAlgorithmOID.ED25519: "ed25519",
|
|
||||||
SignatureAlgorithmOID.ED448: "ed448",
|
|
||||||
SignatureAlgorithmOID.GOSTR3411_94_WITH_3410_2001: (
|
|
||||||
"GOST R 34.11-94 with GOST R 34.10-2001"
|
|
||||||
),
|
|
||||||
SignatureAlgorithmOID.GOSTR3410_2012_WITH_3411_2012_256: (
|
|
||||||
"GOST R 34.10-2012 with GOST R 34.11-2012 (256 bit)"
|
|
||||||
),
|
|
||||||
SignatureAlgorithmOID.GOSTR3410_2012_WITH_3411_2012_512: (
|
|
||||||
"GOST R 34.10-2012 with GOST R 34.11-2012 (512 bit)"
|
|
||||||
),
|
|
||||||
ExtendedKeyUsageOID.SERVER_AUTH: "serverAuth",
|
|
||||||
ExtendedKeyUsageOID.CLIENT_AUTH: "clientAuth",
|
|
||||||
ExtendedKeyUsageOID.CODE_SIGNING: "codeSigning",
|
|
||||||
ExtendedKeyUsageOID.EMAIL_PROTECTION: "emailProtection",
|
|
||||||
ExtendedKeyUsageOID.TIME_STAMPING: "timeStamping",
|
|
||||||
ExtendedKeyUsageOID.OCSP_SIGNING: "OCSPSigning",
|
|
||||||
ExtendedKeyUsageOID.SMARTCARD_LOGON: "msSmartcardLogin",
|
|
||||||
ExtendedKeyUsageOID.KERBEROS_PKINIT_KDC: "pkInitKDC",
|
|
||||||
ExtensionOID.SUBJECT_DIRECTORY_ATTRIBUTES: "subjectDirectoryAttributes",
|
|
||||||
ExtensionOID.SUBJECT_KEY_IDENTIFIER: "subjectKeyIdentifier",
|
|
||||||
ExtensionOID.KEY_USAGE: "keyUsage",
|
|
||||||
ExtensionOID.SUBJECT_ALTERNATIVE_NAME: "subjectAltName",
|
|
||||||
ExtensionOID.ISSUER_ALTERNATIVE_NAME: "issuerAltName",
|
|
||||||
ExtensionOID.BASIC_CONSTRAINTS: "basicConstraints",
|
|
||||||
ExtensionOID.PRECERT_SIGNED_CERTIFICATE_TIMESTAMPS: (
|
|
||||||
"signedCertificateTimestampList"
|
|
||||||
),
|
|
||||||
ExtensionOID.SIGNED_CERTIFICATE_TIMESTAMPS: (
|
|
||||||
"signedCertificateTimestampList"
|
|
||||||
),
|
|
||||||
ExtensionOID.PRECERT_POISON: "ctPoison",
|
|
||||||
CRLEntryExtensionOID.CRL_REASON: "cRLReason",
|
|
||||||
CRLEntryExtensionOID.INVALIDITY_DATE: "invalidityDate",
|
|
||||||
CRLEntryExtensionOID.CERTIFICATE_ISSUER: "certificateIssuer",
|
|
||||||
ExtensionOID.NAME_CONSTRAINTS: "nameConstraints",
|
|
||||||
ExtensionOID.CRL_DISTRIBUTION_POINTS: "cRLDistributionPoints",
|
|
||||||
ExtensionOID.CERTIFICATE_POLICIES: "certificatePolicies",
|
|
||||||
ExtensionOID.POLICY_MAPPINGS: "policyMappings",
|
|
||||||
ExtensionOID.AUTHORITY_KEY_IDENTIFIER: "authorityKeyIdentifier",
|
|
||||||
ExtensionOID.POLICY_CONSTRAINTS: "policyConstraints",
|
|
||||||
ExtensionOID.EXTENDED_KEY_USAGE: "extendedKeyUsage",
|
|
||||||
ExtensionOID.FRESHEST_CRL: "freshestCRL",
|
|
||||||
ExtensionOID.INHIBIT_ANY_POLICY: "inhibitAnyPolicy",
|
|
||||||
ExtensionOID.ISSUING_DISTRIBUTION_POINT: ("issuingDistributionPoint"),
|
|
||||||
ExtensionOID.AUTHORITY_INFORMATION_ACCESS: "authorityInfoAccess",
|
|
||||||
ExtensionOID.SUBJECT_INFORMATION_ACCESS: "subjectInfoAccess",
|
|
||||||
ExtensionOID.OCSP_NO_CHECK: "OCSPNoCheck",
|
|
||||||
ExtensionOID.CRL_NUMBER: "cRLNumber",
|
|
||||||
ExtensionOID.DELTA_CRL_INDICATOR: "deltaCRLIndicator",
|
|
||||||
ExtensionOID.TLS_FEATURE: "TLSFeature",
|
|
||||||
AuthorityInformationAccessOID.OCSP: "OCSP",
|
|
||||||
AuthorityInformationAccessOID.CA_ISSUERS: "caIssuers",
|
|
||||||
SubjectInformationAccessOID.CA_REPOSITORY: "caRepository",
|
|
||||||
CertificatePoliciesOID.CPS_QUALIFIER: "id-qt-cps",
|
|
||||||
CertificatePoliciesOID.CPS_USER_NOTICE: "id-qt-unotice",
|
|
||||||
OCSPExtensionOID.NONCE: "OCSPNonce",
|
|
||||||
AttributeOID.CHALLENGE_PASSWORD: "challengePassword",
|
|
||||||
}
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
def default_backend() -> Any:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import backend
|
|
||||||
|
|
||||||
return backend
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import backend
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["backend"]
|
|
||||||
@@ -1,251 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.exceptions import InvalidTag
|
|
||||||
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import Backend
|
|
||||||
from cryptography.hazmat.primitives.ciphers.aead import (
|
|
||||||
AESCCM,
|
|
||||||
AESGCM,
|
|
||||||
AESOCB3,
|
|
||||||
AESSIV,
|
|
||||||
ChaCha20Poly1305,
|
|
||||||
)
|
|
||||||
|
|
||||||
_AEAD_TYPES = typing.Union[
|
|
||||||
AESCCM, AESGCM, AESOCB3, AESSIV, ChaCha20Poly1305
|
|
||||||
]
|
|
||||||
|
|
||||||
_ENCRYPT = 1
|
|
||||||
_DECRYPT = 0
|
|
||||||
|
|
||||||
|
|
||||||
def _aead_cipher_name(cipher: "_AEAD_TYPES") -> bytes:
|
|
||||||
from cryptography.hazmat.primitives.ciphers.aead import (
|
|
||||||
AESCCM,
|
|
||||||
AESGCM,
|
|
||||||
AESOCB3,
|
|
||||||
AESSIV,
|
|
||||||
ChaCha20Poly1305,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(cipher, ChaCha20Poly1305):
|
|
||||||
return b"chacha20-poly1305"
|
|
||||||
elif isinstance(cipher, AESCCM):
|
|
||||||
return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii")
|
|
||||||
elif isinstance(cipher, AESOCB3):
|
|
||||||
return f"aes-{len(cipher._key) * 8}-ocb".encode("ascii")
|
|
||||||
elif isinstance(cipher, AESSIV):
|
|
||||||
return f"aes-{len(cipher._key) * 8 // 2}-siv".encode("ascii")
|
|
||||||
else:
|
|
||||||
assert isinstance(cipher, AESGCM)
|
|
||||||
return f"aes-{len(cipher._key) * 8}-gcm".encode("ascii")
|
|
||||||
|
|
||||||
|
|
||||||
def _evp_cipher(cipher_name: bytes, backend: "Backend"):
|
|
||||||
if cipher_name.endswith(b"-siv"):
|
|
||||||
evp_cipher = backend._lib.EVP_CIPHER_fetch(
|
|
||||||
backend._ffi.NULL,
|
|
||||||
cipher_name,
|
|
||||||
backend._ffi.NULL,
|
|
||||||
)
|
|
||||||
backend.openssl_assert(evp_cipher != backend._ffi.NULL)
|
|
||||||
evp_cipher = backend._ffi.gc(evp_cipher, backend._lib.EVP_CIPHER_free)
|
|
||||||
else:
|
|
||||||
evp_cipher = backend._lib.EVP_get_cipherbyname(cipher_name)
|
|
||||||
backend.openssl_assert(evp_cipher != backend._ffi.NULL)
|
|
||||||
|
|
||||||
return evp_cipher
|
|
||||||
|
|
||||||
|
|
||||||
def _aead_setup(
|
|
||||||
backend: "Backend",
|
|
||||||
cipher_name: bytes,
|
|
||||||
key: bytes,
|
|
||||||
nonce: bytes,
|
|
||||||
tag: typing.Optional[bytes],
|
|
||||||
tag_len: int,
|
|
||||||
operation: int,
|
|
||||||
):
|
|
||||||
evp_cipher = _evp_cipher(cipher_name, backend)
|
|
||||||
ctx = backend._lib.EVP_CIPHER_CTX_new()
|
|
||||||
ctx = backend._ffi.gc(ctx, backend._lib.EVP_CIPHER_CTX_free)
|
|
||||||
res = backend._lib.EVP_CipherInit_ex(
|
|
||||||
ctx,
|
|
||||||
evp_cipher,
|
|
||||||
backend._ffi.NULL,
|
|
||||||
backend._ffi.NULL,
|
|
||||||
backend._ffi.NULL,
|
|
||||||
int(operation == _ENCRYPT),
|
|
||||||
)
|
|
||||||
backend.openssl_assert(res != 0)
|
|
||||||
res = backend._lib.EVP_CIPHER_CTX_set_key_length(ctx, len(key))
|
|
||||||
backend.openssl_assert(res != 0)
|
|
||||||
res = backend._lib.EVP_CIPHER_CTX_ctrl(
|
|
||||||
ctx,
|
|
||||||
backend._lib.EVP_CTRL_AEAD_SET_IVLEN,
|
|
||||||
len(nonce),
|
|
||||||
backend._ffi.NULL,
|
|
||||||
)
|
|
||||||
backend.openssl_assert(res != 0)
|
|
||||||
if operation == _DECRYPT:
|
|
||||||
assert tag is not None
|
|
||||||
res = backend._lib.EVP_CIPHER_CTX_ctrl(
|
|
||||||
ctx, backend._lib.EVP_CTRL_AEAD_SET_TAG, len(tag), tag
|
|
||||||
)
|
|
||||||
backend.openssl_assert(res != 0)
|
|
||||||
elif cipher_name.endswith(b"-ccm"):
|
|
||||||
res = backend._lib.EVP_CIPHER_CTX_ctrl(
|
|
||||||
ctx, backend._lib.EVP_CTRL_AEAD_SET_TAG, tag_len, backend._ffi.NULL
|
|
||||||
)
|
|
||||||
backend.openssl_assert(res != 0)
|
|
||||||
|
|
||||||
nonce_ptr = backend._ffi.from_buffer(nonce)
|
|
||||||
key_ptr = backend._ffi.from_buffer(key)
|
|
||||||
res = backend._lib.EVP_CipherInit_ex(
|
|
||||||
ctx,
|
|
||||||
backend._ffi.NULL,
|
|
||||||
backend._ffi.NULL,
|
|
||||||
key_ptr,
|
|
||||||
nonce_ptr,
|
|
||||||
int(operation == _ENCRYPT),
|
|
||||||
)
|
|
||||||
backend.openssl_assert(res != 0)
|
|
||||||
return ctx
|
|
||||||
|
|
||||||
|
|
||||||
def _set_length(backend: "Backend", ctx, data_len: int) -> None:
|
|
||||||
intptr = backend._ffi.new("int *")
|
|
||||||
res = backend._lib.EVP_CipherUpdate(
|
|
||||||
ctx, backend._ffi.NULL, intptr, backend._ffi.NULL, data_len
|
|
||||||
)
|
|
||||||
backend.openssl_assert(res != 0)
|
|
||||||
|
|
||||||
|
|
||||||
def _process_aad(backend: "Backend", ctx, associated_data: bytes) -> None:
|
|
||||||
outlen = backend._ffi.new("int *")
|
|
||||||
res = backend._lib.EVP_CipherUpdate(
|
|
||||||
ctx, backend._ffi.NULL, outlen, associated_data, len(associated_data)
|
|
||||||
)
|
|
||||||
backend.openssl_assert(res != 0)
|
|
||||||
|
|
||||||
|
|
||||||
def _process_data(backend: "Backend", ctx, data: bytes) -> bytes:
|
|
||||||
outlen = backend._ffi.new("int *")
|
|
||||||
buf = backend._ffi.new("unsigned char[]", len(data))
|
|
||||||
res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, data, len(data))
|
|
||||||
if res == 0:
|
|
||||||
# AES SIV can error here if the data is invalid on decrypt
|
|
||||||
backend._consume_errors()
|
|
||||||
raise InvalidTag
|
|
||||||
return backend._ffi.buffer(buf, outlen[0])[:]
|
|
||||||
|
|
||||||
|
|
||||||
def _encrypt(
|
|
||||||
backend: "Backend",
|
|
||||||
cipher: "_AEAD_TYPES",
|
|
||||||
nonce: bytes,
|
|
||||||
data: bytes,
|
|
||||||
associated_data: typing.List[bytes],
|
|
||||||
tag_length: int,
|
|
||||||
) -> bytes:
|
|
||||||
from cryptography.hazmat.primitives.ciphers.aead import AESCCM, AESSIV
|
|
||||||
|
|
||||||
cipher_name = _aead_cipher_name(cipher)
|
|
||||||
ctx = _aead_setup(
|
|
||||||
backend, cipher_name, cipher._key, nonce, None, tag_length, _ENCRYPT
|
|
||||||
)
|
|
||||||
# CCM requires us to pass the length of the data before processing anything
|
|
||||||
# However calling this with any other AEAD results in an error
|
|
||||||
if isinstance(cipher, AESCCM):
|
|
||||||
_set_length(backend, ctx, len(data))
|
|
||||||
|
|
||||||
for ad in associated_data:
|
|
||||||
_process_aad(backend, ctx, ad)
|
|
||||||
processed_data = _process_data(backend, ctx, data)
|
|
||||||
outlen = backend._ffi.new("int *")
|
|
||||||
# All AEADs we support besides OCB are streaming so they return nothing
|
|
||||||
# in finalization. OCB can return up to (16 byte block - 1) bytes so
|
|
||||||
# we need a buffer here too.
|
|
||||||
buf = backend._ffi.new("unsigned char[]", 16)
|
|
||||||
res = backend._lib.EVP_CipherFinal_ex(ctx, buf, outlen)
|
|
||||||
backend.openssl_assert(res != 0)
|
|
||||||
processed_data += backend._ffi.buffer(buf, outlen[0])[:]
|
|
||||||
tag_buf = backend._ffi.new("unsigned char[]", tag_length)
|
|
||||||
res = backend._lib.EVP_CIPHER_CTX_ctrl(
|
|
||||||
ctx, backend._lib.EVP_CTRL_AEAD_GET_TAG, tag_length, tag_buf
|
|
||||||
)
|
|
||||||
backend.openssl_assert(res != 0)
|
|
||||||
tag = backend._ffi.buffer(tag_buf)[:]
|
|
||||||
|
|
||||||
if isinstance(cipher, AESSIV):
|
|
||||||
# RFC 5297 defines the output as IV || C, where the tag we generate is
|
|
||||||
# the "IV" and C is the ciphertext. This is the opposite of our
|
|
||||||
# other AEADs, which are Ciphertext || Tag
|
|
||||||
backend.openssl_assert(len(tag) == 16)
|
|
||||||
return tag + processed_data
|
|
||||||
else:
|
|
||||||
return processed_data + tag
|
|
||||||
|
|
||||||
|
|
||||||
def _decrypt(
|
|
||||||
backend: "Backend",
|
|
||||||
cipher: "_AEAD_TYPES",
|
|
||||||
nonce: bytes,
|
|
||||||
data: bytes,
|
|
||||||
associated_data: typing.List[bytes],
|
|
||||||
tag_length: int,
|
|
||||||
) -> bytes:
|
|
||||||
from cryptography.hazmat.primitives.ciphers.aead import AESCCM, AESSIV
|
|
||||||
|
|
||||||
if len(data) < tag_length:
|
|
||||||
raise InvalidTag
|
|
||||||
|
|
||||||
if isinstance(cipher, AESSIV):
|
|
||||||
# RFC 5297 defines the output as IV || C, where the tag we generate is
|
|
||||||
# the "IV" and C is the ciphertext. This is the opposite of our
|
|
||||||
# other AEADs, which are Ciphertext || Tag
|
|
||||||
tag = data[:tag_length]
|
|
||||||
data = data[tag_length:]
|
|
||||||
else:
|
|
||||||
tag = data[-tag_length:]
|
|
||||||
data = data[:-tag_length]
|
|
||||||
cipher_name = _aead_cipher_name(cipher)
|
|
||||||
ctx = _aead_setup(
|
|
||||||
backend, cipher_name, cipher._key, nonce, tag, tag_length, _DECRYPT
|
|
||||||
)
|
|
||||||
# CCM requires us to pass the length of the data before processing anything
|
|
||||||
# However calling this with any other AEAD results in an error
|
|
||||||
if isinstance(cipher, AESCCM):
|
|
||||||
_set_length(backend, ctx, len(data))
|
|
||||||
|
|
||||||
for ad in associated_data:
|
|
||||||
_process_aad(backend, ctx, ad)
|
|
||||||
# CCM has a different error path if the tag doesn't match. Errors are
|
|
||||||
# raised in Update and Final is irrelevant.
|
|
||||||
if isinstance(cipher, AESCCM):
|
|
||||||
outlen = backend._ffi.new("int *")
|
|
||||||
buf = backend._ffi.new("unsigned char[]", len(data))
|
|
||||||
res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, data, len(data))
|
|
||||||
if res != 1:
|
|
||||||
backend._consume_errors()
|
|
||||||
raise InvalidTag
|
|
||||||
|
|
||||||
processed_data = backend._ffi.buffer(buf, outlen[0])[:]
|
|
||||||
else:
|
|
||||||
processed_data = _process_data(backend, ctx, data)
|
|
||||||
outlen = backend._ffi.new("int *")
|
|
||||||
# OCB can return up to 15 bytes (16 byte block - 1) in finalization
|
|
||||||
buf = backend._ffi.new("unsigned char[]", 16)
|
|
||||||
res = backend._lib.EVP_CipherFinal_ex(ctx, buf, outlen)
|
|
||||||
processed_data += backend._ffi.buffer(buf, outlen[0])[:]
|
|
||||||
if res == 0:
|
|
||||||
backend._consume_errors()
|
|
||||||
raise InvalidTag
|
|
||||||
|
|
||||||
return processed_data
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,282 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.exceptions import InvalidTag, UnsupportedAlgorithm, _Reasons
|
|
||||||
from cryptography.hazmat.primitives import ciphers
|
|
||||||
from cryptography.hazmat.primitives.ciphers import algorithms, modes
|
|
||||||
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import Backend
|
|
||||||
|
|
||||||
|
|
||||||
class _CipherContext:
|
|
||||||
_ENCRYPT = 1
|
|
||||||
_DECRYPT = 0
|
|
||||||
_MAX_CHUNK_SIZE = 2**30 - 1
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, backend: "Backend", cipher, mode, operation: int
|
|
||||||
) -> None:
|
|
||||||
self._backend = backend
|
|
||||||
self._cipher = cipher
|
|
||||||
self._mode = mode
|
|
||||||
self._operation = operation
|
|
||||||
self._tag: typing.Optional[bytes] = None
|
|
||||||
|
|
||||||
if isinstance(self._cipher, ciphers.BlockCipherAlgorithm):
|
|
||||||
self._block_size_bytes = self._cipher.block_size // 8
|
|
||||||
else:
|
|
||||||
self._block_size_bytes = 1
|
|
||||||
|
|
||||||
ctx = self._backend._lib.EVP_CIPHER_CTX_new()
|
|
||||||
ctx = self._backend._ffi.gc(
|
|
||||||
ctx, self._backend._lib.EVP_CIPHER_CTX_free
|
|
||||||
)
|
|
||||||
|
|
||||||
registry = self._backend._cipher_registry
|
|
||||||
try:
|
|
||||||
adapter = registry[type(cipher), type(mode)]
|
|
||||||
except KeyError:
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"cipher {} in {} mode is not supported "
|
|
||||||
"by this backend.".format(
|
|
||||||
cipher.name, mode.name if mode else mode
|
|
||||||
),
|
|
||||||
_Reasons.UNSUPPORTED_CIPHER,
|
|
||||||
)
|
|
||||||
|
|
||||||
evp_cipher = adapter(self._backend, cipher, mode)
|
|
||||||
if evp_cipher == self._backend._ffi.NULL:
|
|
||||||
msg = "cipher {0.name} ".format(cipher)
|
|
||||||
if mode is not None:
|
|
||||||
msg += "in {0.name} mode ".format(mode)
|
|
||||||
msg += (
|
|
||||||
"is not supported by this backend (Your version of OpenSSL "
|
|
||||||
"may be too old. Current version: {}.)"
|
|
||||||
).format(self._backend.openssl_version_text())
|
|
||||||
raise UnsupportedAlgorithm(msg, _Reasons.UNSUPPORTED_CIPHER)
|
|
||||||
|
|
||||||
if isinstance(mode, modes.ModeWithInitializationVector):
|
|
||||||
iv_nonce = self._backend._ffi.from_buffer(
|
|
||||||
mode.initialization_vector
|
|
||||||
)
|
|
||||||
elif isinstance(mode, modes.ModeWithTweak):
|
|
||||||
iv_nonce = self._backend._ffi.from_buffer(mode.tweak)
|
|
||||||
elif isinstance(mode, modes.ModeWithNonce):
|
|
||||||
iv_nonce = self._backend._ffi.from_buffer(mode.nonce)
|
|
||||||
elif isinstance(cipher, algorithms.ChaCha20):
|
|
||||||
iv_nonce = self._backend._ffi.from_buffer(cipher.nonce)
|
|
||||||
else:
|
|
||||||
iv_nonce = self._backend._ffi.NULL
|
|
||||||
# begin init with cipher and operation type
|
|
||||||
res = self._backend._lib.EVP_CipherInit_ex(
|
|
||||||
ctx,
|
|
||||||
evp_cipher,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
operation,
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
# set the key length to handle variable key ciphers
|
|
||||||
res = self._backend._lib.EVP_CIPHER_CTX_set_key_length(
|
|
||||||
ctx, len(cipher.key)
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
if isinstance(mode, modes.GCM):
|
|
||||||
res = self._backend._lib.EVP_CIPHER_CTX_ctrl(
|
|
||||||
ctx,
|
|
||||||
self._backend._lib.EVP_CTRL_AEAD_SET_IVLEN,
|
|
||||||
len(iv_nonce),
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
if mode.tag is not None:
|
|
||||||
res = self._backend._lib.EVP_CIPHER_CTX_ctrl(
|
|
||||||
ctx,
|
|
||||||
self._backend._lib.EVP_CTRL_AEAD_SET_TAG,
|
|
||||||
len(mode.tag),
|
|
||||||
mode.tag,
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
self._tag = mode.tag
|
|
||||||
|
|
||||||
# pass key/iv
|
|
||||||
res = self._backend._lib.EVP_CipherInit_ex(
|
|
||||||
ctx,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.from_buffer(cipher.key),
|
|
||||||
iv_nonce,
|
|
||||||
operation,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check for XTS mode duplicate keys error
|
|
||||||
errors = self._backend._consume_errors()
|
|
||||||
lib = self._backend._lib
|
|
||||||
if res == 0 and (
|
|
||||||
(
|
|
||||||
lib.CRYPTOGRAPHY_OPENSSL_111D_OR_GREATER
|
|
||||||
and errors[0]._lib_reason_match(
|
|
||||||
lib.ERR_LIB_EVP, lib.EVP_R_XTS_DUPLICATED_KEYS
|
|
||||||
)
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
lib.Cryptography_HAS_PROVIDERS
|
|
||||||
and errors[0]._lib_reason_match(
|
|
||||||
lib.ERR_LIB_PROV, lib.PROV_R_XTS_DUPLICATED_KEYS
|
|
||||||
)
|
|
||||||
)
|
|
||||||
):
|
|
||||||
raise ValueError("In XTS mode duplicated keys are not allowed")
|
|
||||||
|
|
||||||
self._backend.openssl_assert(res != 0, errors=errors)
|
|
||||||
|
|
||||||
# We purposely disable padding here as it's handled higher up in the
|
|
||||||
# API.
|
|
||||||
self._backend._lib.EVP_CIPHER_CTX_set_padding(ctx, 0)
|
|
||||||
self._ctx = ctx
|
|
||||||
|
|
||||||
def update(self, data: bytes) -> bytes:
|
|
||||||
buf = bytearray(len(data) + self._block_size_bytes - 1)
|
|
||||||
n = self.update_into(data, buf)
|
|
||||||
return bytes(buf[:n])
|
|
||||||
|
|
||||||
def update_into(self, data: bytes, buf: bytes) -> int:
|
|
||||||
total_data_len = len(data)
|
|
||||||
if len(buf) < (total_data_len + self._block_size_bytes - 1):
|
|
||||||
raise ValueError(
|
|
||||||
"buffer must be at least {} bytes for this "
|
|
||||||
"payload".format(len(data) + self._block_size_bytes - 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
data_processed = 0
|
|
||||||
total_out = 0
|
|
||||||
outlen = self._backend._ffi.new("int *")
|
|
||||||
baseoutbuf = self._backend._ffi.from_buffer(buf)
|
|
||||||
baseinbuf = self._backend._ffi.from_buffer(data)
|
|
||||||
|
|
||||||
while data_processed != total_data_len:
|
|
||||||
outbuf = baseoutbuf + total_out
|
|
||||||
inbuf = baseinbuf + data_processed
|
|
||||||
inlen = min(self._MAX_CHUNK_SIZE, total_data_len - data_processed)
|
|
||||||
|
|
||||||
res = self._backend._lib.EVP_CipherUpdate(
|
|
||||||
self._ctx, outbuf, outlen, inbuf, inlen
|
|
||||||
)
|
|
||||||
if res == 0 and isinstance(self._mode, modes.XTS):
|
|
||||||
self._backend._consume_errors()
|
|
||||||
raise ValueError(
|
|
||||||
"In XTS mode you must supply at least a full block in the "
|
|
||||||
"first update call. For AES this is 16 bytes."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
data_processed += inlen
|
|
||||||
total_out += outlen[0]
|
|
||||||
|
|
||||||
return total_out
|
|
||||||
|
|
||||||
def finalize(self) -> bytes:
|
|
||||||
if (
|
|
||||||
self._operation == self._DECRYPT
|
|
||||||
and isinstance(self._mode, modes.ModeWithAuthenticationTag)
|
|
||||||
and self.tag is None
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"Authentication tag must be provided when decrypting."
|
|
||||||
)
|
|
||||||
|
|
||||||
buf = self._backend._ffi.new("unsigned char[]", self._block_size_bytes)
|
|
||||||
outlen = self._backend._ffi.new("int *")
|
|
||||||
res = self._backend._lib.EVP_CipherFinal_ex(self._ctx, buf, outlen)
|
|
||||||
if res == 0:
|
|
||||||
errors = self._backend._consume_errors()
|
|
||||||
|
|
||||||
if not errors and isinstance(self._mode, modes.GCM):
|
|
||||||
raise InvalidTag
|
|
||||||
|
|
||||||
lib = self._backend._lib
|
|
||||||
self._backend.openssl_assert(
|
|
||||||
errors[0]._lib_reason_match(
|
|
||||||
lib.ERR_LIB_EVP,
|
|
||||||
lib.EVP_R_DATA_NOT_MULTIPLE_OF_BLOCK_LENGTH,
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
lib.Cryptography_HAS_PROVIDERS
|
|
||||||
and errors[0]._lib_reason_match(
|
|
||||||
lib.ERR_LIB_PROV,
|
|
||||||
lib.PROV_R_WRONG_FINAL_BLOCK_LENGTH,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
lib.CRYPTOGRAPHY_IS_BORINGSSL
|
|
||||||
and errors[0].reason
|
|
||||||
== lib.CIPHER_R_DATA_NOT_MULTIPLE_OF_BLOCK_LENGTH
|
|
||||||
),
|
|
||||||
errors=errors,
|
|
||||||
)
|
|
||||||
raise ValueError(
|
|
||||||
"The length of the provided data is not a multiple of "
|
|
||||||
"the block length."
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
isinstance(self._mode, modes.GCM)
|
|
||||||
and self._operation == self._ENCRYPT
|
|
||||||
):
|
|
||||||
tag_buf = self._backend._ffi.new(
|
|
||||||
"unsigned char[]", self._block_size_bytes
|
|
||||||
)
|
|
||||||
res = self._backend._lib.EVP_CIPHER_CTX_ctrl(
|
|
||||||
self._ctx,
|
|
||||||
self._backend._lib.EVP_CTRL_AEAD_GET_TAG,
|
|
||||||
self._block_size_bytes,
|
|
||||||
tag_buf,
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
self._tag = self._backend._ffi.buffer(tag_buf)[:]
|
|
||||||
|
|
||||||
res = self._backend._lib.EVP_CIPHER_CTX_reset(self._ctx)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
return self._backend._ffi.buffer(buf)[: outlen[0]]
|
|
||||||
|
|
||||||
def finalize_with_tag(self, tag: bytes) -> bytes:
|
|
||||||
tag_len = len(tag)
|
|
||||||
if tag_len < self._mode._min_tag_length:
|
|
||||||
raise ValueError(
|
|
||||||
"Authentication tag must be {} bytes or longer.".format(
|
|
||||||
self._mode._min_tag_length
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif tag_len > self._block_size_bytes:
|
|
||||||
raise ValueError(
|
|
||||||
"Authentication tag cannot be more than {} bytes.".format(
|
|
||||||
self._block_size_bytes
|
|
||||||
)
|
|
||||||
)
|
|
||||||
res = self._backend._lib.EVP_CIPHER_CTX_ctrl(
|
|
||||||
self._ctx, self._backend._lib.EVP_CTRL_AEAD_SET_TAG, len(tag), tag
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
self._tag = tag
|
|
||||||
return self.finalize()
|
|
||||||
|
|
||||||
def authenticate_additional_data(self, data: bytes) -> None:
|
|
||||||
outlen = self._backend._ffi.new("int *")
|
|
||||||
res = self._backend._lib.EVP_CipherUpdate(
|
|
||||||
self._ctx,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
outlen,
|
|
||||||
self._backend._ffi.from_buffer(data),
|
|
||||||
len(data),
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def tag(self) -> typing.Optional[bytes]:
|
|
||||||
return self._tag
|
|
||||||
@@ -1,87 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.exceptions import (
|
|
||||||
InvalidSignature,
|
|
||||||
UnsupportedAlgorithm,
|
|
||||||
_Reasons,
|
|
||||||
)
|
|
||||||
from cryptography.hazmat.primitives import constant_time
|
|
||||||
from cryptography.hazmat.primitives.ciphers.modes import CBC
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from cryptography.hazmat.primitives import ciphers
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import Backend
|
|
||||||
|
|
||||||
|
|
||||||
class _CMACContext:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
backend: "Backend",
|
|
||||||
algorithm: "ciphers.BlockCipherAlgorithm",
|
|
||||||
ctx=None,
|
|
||||||
) -> None:
|
|
||||||
if not backend.cmac_algorithm_supported(algorithm):
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"This backend does not support CMAC.",
|
|
||||||
_Reasons.UNSUPPORTED_CIPHER,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._backend = backend
|
|
||||||
self._key = algorithm.key
|
|
||||||
self._algorithm = algorithm
|
|
||||||
self._output_length = algorithm.block_size // 8
|
|
||||||
|
|
||||||
if ctx is None:
|
|
||||||
registry = self._backend._cipher_registry
|
|
||||||
adapter = registry[type(algorithm), CBC]
|
|
||||||
|
|
||||||
evp_cipher = adapter(self._backend, algorithm, CBC)
|
|
||||||
|
|
||||||
ctx = self._backend._lib.CMAC_CTX_new()
|
|
||||||
|
|
||||||
self._backend.openssl_assert(ctx != self._backend._ffi.NULL)
|
|
||||||
ctx = self._backend._ffi.gc(ctx, self._backend._lib.CMAC_CTX_free)
|
|
||||||
|
|
||||||
key_ptr = self._backend._ffi.from_buffer(self._key)
|
|
||||||
res = self._backend._lib.CMAC_Init(
|
|
||||||
ctx,
|
|
||||||
key_ptr,
|
|
||||||
len(self._key),
|
|
||||||
evp_cipher,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
|
|
||||||
self._ctx = ctx
|
|
||||||
|
|
||||||
def update(self, data: bytes) -> None:
|
|
||||||
res = self._backend._lib.CMAC_Update(self._ctx, data, len(data))
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
|
|
||||||
def finalize(self) -> bytes:
|
|
||||||
buf = self._backend._ffi.new("unsigned char[]", self._output_length)
|
|
||||||
length = self._backend._ffi.new("size_t *", self._output_length)
|
|
||||||
res = self._backend._lib.CMAC_Final(self._ctx, buf, length)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
|
|
||||||
self._ctx = None
|
|
||||||
|
|
||||||
return self._backend._ffi.buffer(buf)[:]
|
|
||||||
|
|
||||||
def copy(self) -> "_CMACContext":
|
|
||||||
copied_ctx = self._backend._lib.CMAC_CTX_new()
|
|
||||||
copied_ctx = self._backend._ffi.gc(
|
|
||||||
copied_ctx, self._backend._lib.CMAC_CTX_free
|
|
||||||
)
|
|
||||||
res = self._backend._lib.CMAC_CTX_copy(copied_ctx, self._ctx)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
return _CMACContext(self._backend, self._algorithm, ctx=copied_ctx)
|
|
||||||
|
|
||||||
def verify(self, signature: bytes) -> None:
|
|
||||||
digest = self.finalize()
|
|
||||||
if not constant_time.bytes_eq(digest, signature):
|
|
||||||
raise InvalidSignature("Signature did not match digest.")
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
from cryptography import x509
|
|
||||||
|
|
||||||
# CRLReason ::= ENUMERATED {
|
|
||||||
# unspecified (0),
|
|
||||||
# keyCompromise (1),
|
|
||||||
# cACompromise (2),
|
|
||||||
# affiliationChanged (3),
|
|
||||||
# superseded (4),
|
|
||||||
# cessationOfOperation (5),
|
|
||||||
# certificateHold (6),
|
|
||||||
# -- value 7 is not used
|
|
||||||
# removeFromCRL (8),
|
|
||||||
# privilegeWithdrawn (9),
|
|
||||||
# aACompromise (10) }
|
|
||||||
_CRL_ENTRY_REASON_ENUM_TO_CODE = {
|
|
||||||
x509.ReasonFlags.unspecified: 0,
|
|
||||||
x509.ReasonFlags.key_compromise: 1,
|
|
||||||
x509.ReasonFlags.ca_compromise: 2,
|
|
||||||
x509.ReasonFlags.affiliation_changed: 3,
|
|
||||||
x509.ReasonFlags.superseded: 4,
|
|
||||||
x509.ReasonFlags.cessation_of_operation: 5,
|
|
||||||
x509.ReasonFlags.certificate_hold: 6,
|
|
||||||
x509.ReasonFlags.remove_from_crl: 8,
|
|
||||||
x509.ReasonFlags.privilege_withdrawn: 9,
|
|
||||||
x509.ReasonFlags.aa_compromise: 10,
|
|
||||||
}
|
|
||||||
@@ -1,318 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
|
|
||||||
from cryptography.hazmat.primitives import serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import dh
|
|
||||||
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import Backend
|
|
||||||
|
|
||||||
|
|
||||||
def _dh_params_dup(dh_cdata, backend: "Backend"):
|
|
||||||
lib = backend._lib
|
|
||||||
ffi = backend._ffi
|
|
||||||
|
|
||||||
param_cdata = lib.DHparams_dup(dh_cdata)
|
|
||||||
backend.openssl_assert(param_cdata != ffi.NULL)
|
|
||||||
param_cdata = ffi.gc(param_cdata, lib.DH_free)
|
|
||||||
if lib.CRYPTOGRAPHY_IS_LIBRESSL:
|
|
||||||
# In libressl DHparams_dup don't copy q
|
|
||||||
q = ffi.new("BIGNUM **")
|
|
||||||
lib.DH_get0_pqg(dh_cdata, ffi.NULL, q, ffi.NULL)
|
|
||||||
q_dup = lib.BN_dup(q[0])
|
|
||||||
res = lib.DH_set0_pqg(param_cdata, ffi.NULL, q_dup, ffi.NULL)
|
|
||||||
backend.openssl_assert(res == 1)
|
|
||||||
|
|
||||||
return param_cdata
|
|
||||||
|
|
||||||
|
|
||||||
def _dh_cdata_to_parameters(dh_cdata, backend: "Backend") -> "_DHParameters":
|
|
||||||
param_cdata = _dh_params_dup(dh_cdata, backend)
|
|
||||||
return _DHParameters(backend, param_cdata)
|
|
||||||
|
|
||||||
|
|
||||||
class _DHParameters(dh.DHParameters):
|
|
||||||
def __init__(self, backend: "Backend", dh_cdata):
|
|
||||||
self._backend = backend
|
|
||||||
self._dh_cdata = dh_cdata
|
|
||||||
|
|
||||||
def parameter_numbers(self) -> dh.DHParameterNumbers:
|
|
||||||
p = self._backend._ffi.new("BIGNUM **")
|
|
||||||
g = self._backend._ffi.new("BIGNUM **")
|
|
||||||
q = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.DH_get0_pqg(self._dh_cdata, p, q, g)
|
|
||||||
self._backend.openssl_assert(p[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(g[0] != self._backend._ffi.NULL)
|
|
||||||
q_val: typing.Optional[int]
|
|
||||||
if q[0] == self._backend._ffi.NULL:
|
|
||||||
q_val = None
|
|
||||||
else:
|
|
||||||
q_val = self._backend._bn_to_int(q[0])
|
|
||||||
return dh.DHParameterNumbers(
|
|
||||||
p=self._backend._bn_to_int(p[0]),
|
|
||||||
g=self._backend._bn_to_int(g[0]),
|
|
||||||
q=q_val,
|
|
||||||
)
|
|
||||||
|
|
||||||
def generate_private_key(self) -> dh.DHPrivateKey:
|
|
||||||
return self._backend.generate_dh_private_key(self)
|
|
||||||
|
|
||||||
def parameter_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.ParameterFormat,
|
|
||||||
) -> bytes:
|
|
||||||
if encoding is serialization.Encoding.OpenSSH:
|
|
||||||
raise TypeError("OpenSSH encoding is not supported")
|
|
||||||
|
|
||||||
if format is not serialization.ParameterFormat.PKCS3:
|
|
||||||
raise ValueError("Only PKCS3 serialization is supported")
|
|
||||||
|
|
||||||
q = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.DH_get0_pqg(
|
|
||||||
self._dh_cdata, self._backend._ffi.NULL, q, self._backend._ffi.NULL
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
q[0] != self._backend._ffi.NULL
|
|
||||||
and not self._backend._lib.Cryptography_HAS_EVP_PKEY_DHX
|
|
||||||
):
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"DH X9.42 serialization is not supported",
|
|
||||||
_Reasons.UNSUPPORTED_SERIALIZATION,
|
|
||||||
)
|
|
||||||
|
|
||||||
if encoding is serialization.Encoding.PEM:
|
|
||||||
if q[0] != self._backend._ffi.NULL:
|
|
||||||
write_bio = self._backend._lib.PEM_write_bio_DHxparams
|
|
||||||
else:
|
|
||||||
write_bio = self._backend._lib.PEM_write_bio_DHparams
|
|
||||||
elif encoding is serialization.Encoding.DER:
|
|
||||||
if q[0] != self._backend._ffi.NULL:
|
|
||||||
write_bio = self._backend._lib.Cryptography_i2d_DHxparams_bio
|
|
||||||
else:
|
|
||||||
write_bio = self._backend._lib.i2d_DHparams_bio
|
|
||||||
else:
|
|
||||||
raise TypeError("encoding must be an item from the Encoding enum")
|
|
||||||
|
|
||||||
bio = self._backend._create_mem_bio_gc()
|
|
||||||
res = write_bio(bio, self._dh_cdata)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
return self._backend._read_mem_bio(bio)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_dh_num_bits(backend, dh_cdata) -> int:
|
|
||||||
p = backend._ffi.new("BIGNUM **")
|
|
||||||
backend._lib.DH_get0_pqg(dh_cdata, p, backend._ffi.NULL, backend._ffi.NULL)
|
|
||||||
backend.openssl_assert(p[0] != backend._ffi.NULL)
|
|
||||||
return backend._lib.BN_num_bits(p[0])
|
|
||||||
|
|
||||||
|
|
||||||
class _DHPrivateKey(dh.DHPrivateKey):
|
|
||||||
def __init__(self, backend: "Backend", dh_cdata, evp_pkey):
|
|
||||||
self._backend = backend
|
|
||||||
self._dh_cdata = dh_cdata
|
|
||||||
self._evp_pkey = evp_pkey
|
|
||||||
self._key_size_bytes = self._backend._lib.DH_size(dh_cdata)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def key_size(self) -> int:
|
|
||||||
return _get_dh_num_bits(self._backend, self._dh_cdata)
|
|
||||||
|
|
||||||
def private_numbers(self) -> dh.DHPrivateNumbers:
|
|
||||||
p = self._backend._ffi.new("BIGNUM **")
|
|
||||||
g = self._backend._ffi.new("BIGNUM **")
|
|
||||||
q = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.DH_get0_pqg(self._dh_cdata, p, q, g)
|
|
||||||
self._backend.openssl_assert(p[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(g[0] != self._backend._ffi.NULL)
|
|
||||||
if q[0] == self._backend._ffi.NULL:
|
|
||||||
q_val = None
|
|
||||||
else:
|
|
||||||
q_val = self._backend._bn_to_int(q[0])
|
|
||||||
pub_key = self._backend._ffi.new("BIGNUM **")
|
|
||||||
priv_key = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.DH_get0_key(self._dh_cdata, pub_key, priv_key)
|
|
||||||
self._backend.openssl_assert(pub_key[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(priv_key[0] != self._backend._ffi.NULL)
|
|
||||||
return dh.DHPrivateNumbers(
|
|
||||||
public_numbers=dh.DHPublicNumbers(
|
|
||||||
parameter_numbers=dh.DHParameterNumbers(
|
|
||||||
p=self._backend._bn_to_int(p[0]),
|
|
||||||
g=self._backend._bn_to_int(g[0]),
|
|
||||||
q=q_val,
|
|
||||||
),
|
|
||||||
y=self._backend._bn_to_int(pub_key[0]),
|
|
||||||
),
|
|
||||||
x=self._backend._bn_to_int(priv_key[0]),
|
|
||||||
)
|
|
||||||
|
|
||||||
def exchange(self, peer_public_key: dh.DHPublicKey) -> bytes:
|
|
||||||
if not isinstance(peer_public_key, _DHPublicKey):
|
|
||||||
raise TypeError("peer_public_key must be a DHPublicKey")
|
|
||||||
|
|
||||||
ctx = self._backend._lib.EVP_PKEY_CTX_new(
|
|
||||||
self._evp_pkey, self._backend._ffi.NULL
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(ctx != self._backend._ffi.NULL)
|
|
||||||
ctx = self._backend._ffi.gc(ctx, self._backend._lib.EVP_PKEY_CTX_free)
|
|
||||||
res = self._backend._lib.EVP_PKEY_derive_init(ctx)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
res = self._backend._lib.EVP_PKEY_derive_set_peer(
|
|
||||||
ctx, peer_public_key._evp_pkey
|
|
||||||
)
|
|
||||||
# Invalid kex errors here in OpenSSL 3.0 because checks were moved
|
|
||||||
# to EVP_PKEY_derive_set_peer
|
|
||||||
self._exchange_assert(res == 1)
|
|
||||||
keylen = self._backend._ffi.new("size_t *")
|
|
||||||
res = self._backend._lib.EVP_PKEY_derive(
|
|
||||||
ctx, self._backend._ffi.NULL, keylen
|
|
||||||
)
|
|
||||||
# Invalid kex errors here in OpenSSL < 3
|
|
||||||
self._exchange_assert(res == 1)
|
|
||||||
self._backend.openssl_assert(keylen[0] > 0)
|
|
||||||
buf = self._backend._ffi.new("unsigned char[]", keylen[0])
|
|
||||||
res = self._backend._lib.EVP_PKEY_derive(ctx, buf, keylen)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
|
|
||||||
key = self._backend._ffi.buffer(buf, keylen[0])[:]
|
|
||||||
pad = self._key_size_bytes - len(key)
|
|
||||||
|
|
||||||
if pad > 0:
|
|
||||||
key = (b"\x00" * pad) + key
|
|
||||||
|
|
||||||
return key
|
|
||||||
|
|
||||||
def _exchange_assert(self, ok: bool) -> None:
|
|
||||||
if not ok:
|
|
||||||
errors_with_text = self._backend._consume_errors_with_text()
|
|
||||||
raise ValueError(
|
|
||||||
"Error computing shared key.",
|
|
||||||
errors_with_text,
|
|
||||||
)
|
|
||||||
|
|
||||||
def public_key(self) -> dh.DHPublicKey:
|
|
||||||
dh_cdata = _dh_params_dup(self._dh_cdata, self._backend)
|
|
||||||
pub_key = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.DH_get0_key(
|
|
||||||
self._dh_cdata, pub_key, self._backend._ffi.NULL
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(pub_key[0] != self._backend._ffi.NULL)
|
|
||||||
pub_key_dup = self._backend._lib.BN_dup(pub_key[0])
|
|
||||||
self._backend.openssl_assert(pub_key_dup != self._backend._ffi.NULL)
|
|
||||||
|
|
||||||
res = self._backend._lib.DH_set0_key(
|
|
||||||
dh_cdata, pub_key_dup, self._backend._ffi.NULL
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
evp_pkey = self._backend._dh_cdata_to_evp_pkey(dh_cdata)
|
|
||||||
return _DHPublicKey(self._backend, dh_cdata, evp_pkey)
|
|
||||||
|
|
||||||
def parameters(self) -> dh.DHParameters:
|
|
||||||
return _dh_cdata_to_parameters(self._dh_cdata, self._backend)
|
|
||||||
|
|
||||||
def private_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.PrivateFormat,
|
|
||||||
encryption_algorithm: serialization.KeySerializationEncryption,
|
|
||||||
) -> bytes:
|
|
||||||
if format is not serialization.PrivateFormat.PKCS8:
|
|
||||||
raise ValueError(
|
|
||||||
"DH private keys support only PKCS8 serialization"
|
|
||||||
)
|
|
||||||
if not self._backend._lib.Cryptography_HAS_EVP_PKEY_DHX:
|
|
||||||
q = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.DH_get0_pqg(
|
|
||||||
self._dh_cdata,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
q,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
)
|
|
||||||
if q[0] != self._backend._ffi.NULL:
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"DH X9.42 serialization is not supported",
|
|
||||||
_Reasons.UNSUPPORTED_SERIALIZATION,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._backend._private_key_bytes(
|
|
||||||
encoding,
|
|
||||||
format,
|
|
||||||
encryption_algorithm,
|
|
||||||
self,
|
|
||||||
self._evp_pkey,
|
|
||||||
self._dh_cdata,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _DHPublicKey(dh.DHPublicKey):
|
|
||||||
def __init__(self, backend: "Backend", dh_cdata, evp_pkey):
|
|
||||||
self._backend = backend
|
|
||||||
self._dh_cdata = dh_cdata
|
|
||||||
self._evp_pkey = evp_pkey
|
|
||||||
self._key_size_bits = _get_dh_num_bits(self._backend, self._dh_cdata)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def key_size(self) -> int:
|
|
||||||
return self._key_size_bits
|
|
||||||
|
|
||||||
def public_numbers(self) -> dh.DHPublicNumbers:
|
|
||||||
p = self._backend._ffi.new("BIGNUM **")
|
|
||||||
g = self._backend._ffi.new("BIGNUM **")
|
|
||||||
q = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.DH_get0_pqg(self._dh_cdata, p, q, g)
|
|
||||||
self._backend.openssl_assert(p[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(g[0] != self._backend._ffi.NULL)
|
|
||||||
if q[0] == self._backend._ffi.NULL:
|
|
||||||
q_val = None
|
|
||||||
else:
|
|
||||||
q_val = self._backend._bn_to_int(q[0])
|
|
||||||
pub_key = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.DH_get0_key(
|
|
||||||
self._dh_cdata, pub_key, self._backend._ffi.NULL
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(pub_key[0] != self._backend._ffi.NULL)
|
|
||||||
return dh.DHPublicNumbers(
|
|
||||||
parameter_numbers=dh.DHParameterNumbers(
|
|
||||||
p=self._backend._bn_to_int(p[0]),
|
|
||||||
g=self._backend._bn_to_int(g[0]),
|
|
||||||
q=q_val,
|
|
||||||
),
|
|
||||||
y=self._backend._bn_to_int(pub_key[0]),
|
|
||||||
)
|
|
||||||
|
|
||||||
def parameters(self) -> dh.DHParameters:
|
|
||||||
return _dh_cdata_to_parameters(self._dh_cdata, self._backend)
|
|
||||||
|
|
||||||
def public_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.PublicFormat,
|
|
||||||
) -> bytes:
|
|
||||||
if format is not serialization.PublicFormat.SubjectPublicKeyInfo:
|
|
||||||
raise ValueError(
|
|
||||||
"DH public keys support only "
|
|
||||||
"SubjectPublicKeyInfo serialization"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self._backend._lib.Cryptography_HAS_EVP_PKEY_DHX:
|
|
||||||
q = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.DH_get0_pqg(
|
|
||||||
self._dh_cdata,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
q,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
)
|
|
||||||
if q[0] != self._backend._ffi.NULL:
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"DH X9.42 serialization is not supported",
|
|
||||||
_Reasons.UNSUPPORTED_SERIALIZATION,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._backend._public_key_bytes(
|
|
||||||
encoding, format, self, self._evp_pkey, None
|
|
||||||
)
|
|
||||||
@@ -1,239 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.exceptions import InvalidSignature
|
|
||||||
from cryptography.hazmat.backends.openssl.utils import (
|
|
||||||
_calculate_digest_and_algorithm,
|
|
||||||
)
|
|
||||||
from cryptography.hazmat.primitives import hashes, serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import (
|
|
||||||
dsa,
|
|
||||||
utils as asym_utils,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import Backend
|
|
||||||
|
|
||||||
|
|
||||||
def _dsa_sig_sign(
|
|
||||||
backend: "Backend", private_key: "_DSAPrivateKey", data: bytes
|
|
||||||
) -> bytes:
|
|
||||||
sig_buf_len = backend._lib.DSA_size(private_key._dsa_cdata)
|
|
||||||
sig_buf = backend._ffi.new("unsigned char[]", sig_buf_len)
|
|
||||||
buflen = backend._ffi.new("unsigned int *")
|
|
||||||
|
|
||||||
# The first parameter passed to DSA_sign is unused by OpenSSL but
|
|
||||||
# must be an integer.
|
|
||||||
res = backend._lib.DSA_sign(
|
|
||||||
0, data, len(data), sig_buf, buflen, private_key._dsa_cdata
|
|
||||||
)
|
|
||||||
backend.openssl_assert(res == 1)
|
|
||||||
backend.openssl_assert(buflen[0])
|
|
||||||
|
|
||||||
return backend._ffi.buffer(sig_buf)[: buflen[0]]
|
|
||||||
|
|
||||||
|
|
||||||
def _dsa_sig_verify(
|
|
||||||
backend: "Backend",
|
|
||||||
public_key: "_DSAPublicKey",
|
|
||||||
signature: bytes,
|
|
||||||
data: bytes,
|
|
||||||
) -> None:
|
|
||||||
# The first parameter passed to DSA_verify is unused by OpenSSL but
|
|
||||||
# must be an integer.
|
|
||||||
res = backend._lib.DSA_verify(
|
|
||||||
0, data, len(data), signature, len(signature), public_key._dsa_cdata
|
|
||||||
)
|
|
||||||
|
|
||||||
if res != 1:
|
|
||||||
backend._consume_errors()
|
|
||||||
raise InvalidSignature
|
|
||||||
|
|
||||||
|
|
||||||
class _DSAParameters(dsa.DSAParameters):
|
|
||||||
def __init__(self, backend: "Backend", dsa_cdata):
|
|
||||||
self._backend = backend
|
|
||||||
self._dsa_cdata = dsa_cdata
|
|
||||||
|
|
||||||
def parameter_numbers(self) -> dsa.DSAParameterNumbers:
|
|
||||||
p = self._backend._ffi.new("BIGNUM **")
|
|
||||||
q = self._backend._ffi.new("BIGNUM **")
|
|
||||||
g = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.DSA_get0_pqg(self._dsa_cdata, p, q, g)
|
|
||||||
self._backend.openssl_assert(p[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(q[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(g[0] != self._backend._ffi.NULL)
|
|
||||||
return dsa.DSAParameterNumbers(
|
|
||||||
p=self._backend._bn_to_int(p[0]),
|
|
||||||
q=self._backend._bn_to_int(q[0]),
|
|
||||||
g=self._backend._bn_to_int(g[0]),
|
|
||||||
)
|
|
||||||
|
|
||||||
def generate_private_key(self) -> dsa.DSAPrivateKey:
|
|
||||||
return self._backend.generate_dsa_private_key(self)
|
|
||||||
|
|
||||||
|
|
||||||
class _DSAPrivateKey(dsa.DSAPrivateKey):
|
|
||||||
_key_size: int
|
|
||||||
|
|
||||||
def __init__(self, backend: "Backend", dsa_cdata, evp_pkey):
|
|
||||||
self._backend = backend
|
|
||||||
self._dsa_cdata = dsa_cdata
|
|
||||||
self._evp_pkey = evp_pkey
|
|
||||||
|
|
||||||
p = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.DSA_get0_pqg(
|
|
||||||
dsa_cdata, p, self._backend._ffi.NULL, self._backend._ffi.NULL
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(p[0] != backend._ffi.NULL)
|
|
||||||
self._key_size = self._backend._lib.BN_num_bits(p[0])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def key_size(self) -> int:
|
|
||||||
return self._key_size
|
|
||||||
|
|
||||||
def private_numbers(self) -> dsa.DSAPrivateNumbers:
|
|
||||||
p = self._backend._ffi.new("BIGNUM **")
|
|
||||||
q = self._backend._ffi.new("BIGNUM **")
|
|
||||||
g = self._backend._ffi.new("BIGNUM **")
|
|
||||||
pub_key = self._backend._ffi.new("BIGNUM **")
|
|
||||||
priv_key = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.DSA_get0_pqg(self._dsa_cdata, p, q, g)
|
|
||||||
self._backend.openssl_assert(p[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(q[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(g[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend._lib.DSA_get0_key(self._dsa_cdata, pub_key, priv_key)
|
|
||||||
self._backend.openssl_assert(pub_key[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(priv_key[0] != self._backend._ffi.NULL)
|
|
||||||
return dsa.DSAPrivateNumbers(
|
|
||||||
public_numbers=dsa.DSAPublicNumbers(
|
|
||||||
parameter_numbers=dsa.DSAParameterNumbers(
|
|
||||||
p=self._backend._bn_to_int(p[0]),
|
|
||||||
q=self._backend._bn_to_int(q[0]),
|
|
||||||
g=self._backend._bn_to_int(g[0]),
|
|
||||||
),
|
|
||||||
y=self._backend._bn_to_int(pub_key[0]),
|
|
||||||
),
|
|
||||||
x=self._backend._bn_to_int(priv_key[0]),
|
|
||||||
)
|
|
||||||
|
|
||||||
def public_key(self) -> dsa.DSAPublicKey:
|
|
||||||
dsa_cdata = self._backend._lib.DSAparams_dup(self._dsa_cdata)
|
|
||||||
self._backend.openssl_assert(dsa_cdata != self._backend._ffi.NULL)
|
|
||||||
dsa_cdata = self._backend._ffi.gc(
|
|
||||||
dsa_cdata, self._backend._lib.DSA_free
|
|
||||||
)
|
|
||||||
pub_key = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.DSA_get0_key(
|
|
||||||
self._dsa_cdata, pub_key, self._backend._ffi.NULL
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(pub_key[0] != self._backend._ffi.NULL)
|
|
||||||
pub_key_dup = self._backend._lib.BN_dup(pub_key[0])
|
|
||||||
res = self._backend._lib.DSA_set0_key(
|
|
||||||
dsa_cdata, pub_key_dup, self._backend._ffi.NULL
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
evp_pkey = self._backend._dsa_cdata_to_evp_pkey(dsa_cdata)
|
|
||||||
return _DSAPublicKey(self._backend, dsa_cdata, evp_pkey)
|
|
||||||
|
|
||||||
def parameters(self) -> dsa.DSAParameters:
|
|
||||||
dsa_cdata = self._backend._lib.DSAparams_dup(self._dsa_cdata)
|
|
||||||
self._backend.openssl_assert(dsa_cdata != self._backend._ffi.NULL)
|
|
||||||
dsa_cdata = self._backend._ffi.gc(
|
|
||||||
dsa_cdata, self._backend._lib.DSA_free
|
|
||||||
)
|
|
||||||
return _DSAParameters(self._backend, dsa_cdata)
|
|
||||||
|
|
||||||
def private_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.PrivateFormat,
|
|
||||||
encryption_algorithm: serialization.KeySerializationEncryption,
|
|
||||||
) -> bytes:
|
|
||||||
return self._backend._private_key_bytes(
|
|
||||||
encoding,
|
|
||||||
format,
|
|
||||||
encryption_algorithm,
|
|
||||||
self,
|
|
||||||
self._evp_pkey,
|
|
||||||
self._dsa_cdata,
|
|
||||||
)
|
|
||||||
|
|
||||||
def sign(
|
|
||||||
self,
|
|
||||||
data: bytes,
|
|
||||||
algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
|
|
||||||
) -> bytes:
|
|
||||||
data, _ = _calculate_digest_and_algorithm(data, algorithm)
|
|
||||||
return _dsa_sig_sign(self._backend, self, data)
|
|
||||||
|
|
||||||
|
|
||||||
class _DSAPublicKey(dsa.DSAPublicKey):
|
|
||||||
_key_size: int
|
|
||||||
|
|
||||||
def __init__(self, backend: "Backend", dsa_cdata, evp_pkey):
|
|
||||||
self._backend = backend
|
|
||||||
self._dsa_cdata = dsa_cdata
|
|
||||||
self._evp_pkey = evp_pkey
|
|
||||||
p = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.DSA_get0_pqg(
|
|
||||||
dsa_cdata, p, self._backend._ffi.NULL, self._backend._ffi.NULL
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(p[0] != backend._ffi.NULL)
|
|
||||||
self._key_size = self._backend._lib.BN_num_bits(p[0])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def key_size(self) -> int:
|
|
||||||
return self._key_size
|
|
||||||
|
|
||||||
def public_numbers(self) -> dsa.DSAPublicNumbers:
|
|
||||||
p = self._backend._ffi.new("BIGNUM **")
|
|
||||||
q = self._backend._ffi.new("BIGNUM **")
|
|
||||||
g = self._backend._ffi.new("BIGNUM **")
|
|
||||||
pub_key = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.DSA_get0_pqg(self._dsa_cdata, p, q, g)
|
|
||||||
self._backend.openssl_assert(p[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(q[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(g[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend._lib.DSA_get0_key(
|
|
||||||
self._dsa_cdata, pub_key, self._backend._ffi.NULL
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(pub_key[0] != self._backend._ffi.NULL)
|
|
||||||
return dsa.DSAPublicNumbers(
|
|
||||||
parameter_numbers=dsa.DSAParameterNumbers(
|
|
||||||
p=self._backend._bn_to_int(p[0]),
|
|
||||||
q=self._backend._bn_to_int(q[0]),
|
|
||||||
g=self._backend._bn_to_int(g[0]),
|
|
||||||
),
|
|
||||||
y=self._backend._bn_to_int(pub_key[0]),
|
|
||||||
)
|
|
||||||
|
|
||||||
def parameters(self) -> dsa.DSAParameters:
|
|
||||||
dsa_cdata = self._backend._lib.DSAparams_dup(self._dsa_cdata)
|
|
||||||
dsa_cdata = self._backend._ffi.gc(
|
|
||||||
dsa_cdata, self._backend._lib.DSA_free
|
|
||||||
)
|
|
||||||
return _DSAParameters(self._backend, dsa_cdata)
|
|
||||||
|
|
||||||
def public_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.PublicFormat,
|
|
||||||
) -> bytes:
|
|
||||||
return self._backend._public_key_bytes(
|
|
||||||
encoding, format, self, self._evp_pkey, None
|
|
||||||
)
|
|
||||||
|
|
||||||
def verify(
|
|
||||||
self,
|
|
||||||
signature: bytes,
|
|
||||||
data: bytes,
|
|
||||||
algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
|
|
||||||
) -> None:
|
|
||||||
data, _ = _calculate_digest_and_algorithm(data, algorithm)
|
|
||||||
return _dsa_sig_verify(self._backend, self, signature, data)
|
|
||||||
@@ -1,315 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.exceptions import (
|
|
||||||
InvalidSignature,
|
|
||||||
UnsupportedAlgorithm,
|
|
||||||
_Reasons,
|
|
||||||
)
|
|
||||||
from cryptography.hazmat.backends.openssl.utils import (
|
|
||||||
_calculate_digest_and_algorithm,
|
|
||||||
_evp_pkey_derive,
|
|
||||||
)
|
|
||||||
from cryptography.hazmat.primitives import serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import ec
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import Backend
|
|
||||||
|
|
||||||
|
|
||||||
def _check_signature_algorithm(
|
|
||||||
signature_algorithm: ec.EllipticCurveSignatureAlgorithm,
|
|
||||||
) -> None:
|
|
||||||
if not isinstance(signature_algorithm, ec.ECDSA):
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"Unsupported elliptic curve signature algorithm.",
|
|
||||||
_Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _ec_key_curve_sn(backend: "Backend", ec_key) -> str:
|
|
||||||
group = backend._lib.EC_KEY_get0_group(ec_key)
|
|
||||||
backend.openssl_assert(group != backend._ffi.NULL)
|
|
||||||
|
|
||||||
nid = backend._lib.EC_GROUP_get_curve_name(group)
|
|
||||||
# The following check is to find EC keys with unnamed curves and raise
|
|
||||||
# an error for now.
|
|
||||||
if nid == backend._lib.NID_undef:
|
|
||||||
raise ValueError(
|
|
||||||
"ECDSA keys with explicit parameters are unsupported at this time"
|
|
||||||
)
|
|
||||||
|
|
||||||
# This is like the above check, but it also catches the case where you
|
|
||||||
# explicitly encoded a curve with the same parameters as a named curve.
|
|
||||||
# Don't do that.
|
|
||||||
if (
|
|
||||||
not backend._lib.CRYPTOGRAPHY_IS_LIBRESSL
|
|
||||||
and backend._lib.EC_GROUP_get_asn1_flag(group) == 0
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"ECDSA keys with explicit parameters are unsupported at this time"
|
|
||||||
)
|
|
||||||
|
|
||||||
curve_name = backend._lib.OBJ_nid2sn(nid)
|
|
||||||
backend.openssl_assert(curve_name != backend._ffi.NULL)
|
|
||||||
|
|
||||||
sn = backend._ffi.string(curve_name).decode("ascii")
|
|
||||||
return sn
|
|
||||||
|
|
||||||
|
|
||||||
def _mark_asn1_named_ec_curve(backend: "Backend", ec_cdata):
|
|
||||||
"""
|
|
||||||
Set the named curve flag on the EC_KEY. This causes OpenSSL to
|
|
||||||
serialize EC keys along with their curve OID which makes
|
|
||||||
deserialization easier.
|
|
||||||
"""
|
|
||||||
|
|
||||||
backend._lib.EC_KEY_set_asn1_flag(
|
|
||||||
ec_cdata, backend._lib.OPENSSL_EC_NAMED_CURVE
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_key_infinity(backend: "Backend", ec_cdata) -> None:
|
|
||||||
point = backend._lib.EC_KEY_get0_public_key(ec_cdata)
|
|
||||||
backend.openssl_assert(point != backend._ffi.NULL)
|
|
||||||
group = backend._lib.EC_KEY_get0_group(ec_cdata)
|
|
||||||
backend.openssl_assert(group != backend._ffi.NULL)
|
|
||||||
if backend._lib.EC_POINT_is_at_infinity(group, point):
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot load an EC public key where the point is at infinity"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _sn_to_elliptic_curve(backend: "Backend", sn: str) -> ec.EllipticCurve:
|
|
||||||
try:
|
|
||||||
return ec._CURVE_TYPES[sn]()
|
|
||||||
except KeyError:
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"{} is not a supported elliptic curve".format(sn),
|
|
||||||
_Reasons.UNSUPPORTED_ELLIPTIC_CURVE,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _ecdsa_sig_sign(
|
|
||||||
backend: "Backend", private_key: "_EllipticCurvePrivateKey", data: bytes
|
|
||||||
) -> bytes:
|
|
||||||
max_size = backend._lib.ECDSA_size(private_key._ec_key)
|
|
||||||
backend.openssl_assert(max_size > 0)
|
|
||||||
|
|
||||||
sigbuf = backend._ffi.new("unsigned char[]", max_size)
|
|
||||||
siglen_ptr = backend._ffi.new("unsigned int[]", 1)
|
|
||||||
res = backend._lib.ECDSA_sign(
|
|
||||||
0, data, len(data), sigbuf, siglen_ptr, private_key._ec_key
|
|
||||||
)
|
|
||||||
backend.openssl_assert(res == 1)
|
|
||||||
return backend._ffi.buffer(sigbuf)[: siglen_ptr[0]]
|
|
||||||
|
|
||||||
|
|
||||||
def _ecdsa_sig_verify(
|
|
||||||
backend: "Backend",
|
|
||||||
public_key: "_EllipticCurvePublicKey",
|
|
||||||
signature: bytes,
|
|
||||||
data: bytes,
|
|
||||||
) -> None:
|
|
||||||
res = backend._lib.ECDSA_verify(
|
|
||||||
0, data, len(data), signature, len(signature), public_key._ec_key
|
|
||||||
)
|
|
||||||
if res != 1:
|
|
||||||
backend._consume_errors()
|
|
||||||
raise InvalidSignature
|
|
||||||
|
|
||||||
|
|
||||||
class _EllipticCurvePrivateKey(ec.EllipticCurvePrivateKey):
|
|
||||||
def __init__(self, backend: "Backend", ec_key_cdata, evp_pkey):
|
|
||||||
self._backend = backend
|
|
||||||
self._ec_key = ec_key_cdata
|
|
||||||
self._evp_pkey = evp_pkey
|
|
||||||
|
|
||||||
sn = _ec_key_curve_sn(backend, ec_key_cdata)
|
|
||||||
self._curve = _sn_to_elliptic_curve(backend, sn)
|
|
||||||
_mark_asn1_named_ec_curve(backend, ec_key_cdata)
|
|
||||||
_check_key_infinity(backend, ec_key_cdata)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def curve(self) -> ec.EllipticCurve:
|
|
||||||
return self._curve
|
|
||||||
|
|
||||||
@property
|
|
||||||
def key_size(self) -> int:
|
|
||||||
return self.curve.key_size
|
|
||||||
|
|
||||||
def exchange(
|
|
||||||
self, algorithm: ec.ECDH, peer_public_key: ec.EllipticCurvePublicKey
|
|
||||||
) -> bytes:
|
|
||||||
if not (
|
|
||||||
self._backend.elliptic_curve_exchange_algorithm_supported(
|
|
||||||
algorithm, self.curve
|
|
||||||
)
|
|
||||||
):
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"This backend does not support the ECDH algorithm.",
|
|
||||||
_Reasons.UNSUPPORTED_EXCHANGE_ALGORITHM,
|
|
||||||
)
|
|
||||||
|
|
||||||
if peer_public_key.curve.name != self.curve.name:
|
|
||||||
raise ValueError(
|
|
||||||
"peer_public_key and self are not on the same curve"
|
|
||||||
)
|
|
||||||
|
|
||||||
return _evp_pkey_derive(self._backend, self._evp_pkey, peer_public_key)
|
|
||||||
|
|
||||||
def public_key(self) -> ec.EllipticCurvePublicKey:
|
|
||||||
group = self._backend._lib.EC_KEY_get0_group(self._ec_key)
|
|
||||||
self._backend.openssl_assert(group != self._backend._ffi.NULL)
|
|
||||||
|
|
||||||
curve_nid = self._backend._lib.EC_GROUP_get_curve_name(group)
|
|
||||||
public_ec_key = self._backend._ec_key_new_by_curve_nid(curve_nid)
|
|
||||||
|
|
||||||
point = self._backend._lib.EC_KEY_get0_public_key(self._ec_key)
|
|
||||||
self._backend.openssl_assert(point != self._backend._ffi.NULL)
|
|
||||||
|
|
||||||
res = self._backend._lib.EC_KEY_set_public_key(public_ec_key, point)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
|
|
||||||
evp_pkey = self._backend._ec_cdata_to_evp_pkey(public_ec_key)
|
|
||||||
|
|
||||||
return _EllipticCurvePublicKey(self._backend, public_ec_key, evp_pkey)
|
|
||||||
|
|
||||||
def private_numbers(self) -> ec.EllipticCurvePrivateNumbers:
|
|
||||||
bn = self._backend._lib.EC_KEY_get0_private_key(self._ec_key)
|
|
||||||
private_value = self._backend._bn_to_int(bn)
|
|
||||||
return ec.EllipticCurvePrivateNumbers(
|
|
||||||
private_value=private_value,
|
|
||||||
public_numbers=self.public_key().public_numbers(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def private_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.PrivateFormat,
|
|
||||||
encryption_algorithm: serialization.KeySerializationEncryption,
|
|
||||||
) -> bytes:
|
|
||||||
return self._backend._private_key_bytes(
|
|
||||||
encoding,
|
|
||||||
format,
|
|
||||||
encryption_algorithm,
|
|
||||||
self,
|
|
||||||
self._evp_pkey,
|
|
||||||
self._ec_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
def sign(
|
|
||||||
self,
|
|
||||||
data: bytes,
|
|
||||||
signature_algorithm: ec.EllipticCurveSignatureAlgorithm,
|
|
||||||
) -> bytes:
|
|
||||||
_check_signature_algorithm(signature_algorithm)
|
|
||||||
data, _ = _calculate_digest_and_algorithm(
|
|
||||||
data,
|
|
||||||
signature_algorithm.algorithm,
|
|
||||||
)
|
|
||||||
return _ecdsa_sig_sign(self._backend, self, data)
|
|
||||||
|
|
||||||
|
|
||||||
class _EllipticCurvePublicKey(ec.EllipticCurvePublicKey):
|
|
||||||
def __init__(self, backend: "Backend", ec_key_cdata, evp_pkey):
|
|
||||||
self._backend = backend
|
|
||||||
self._ec_key = ec_key_cdata
|
|
||||||
self._evp_pkey = evp_pkey
|
|
||||||
|
|
||||||
sn = _ec_key_curve_sn(backend, ec_key_cdata)
|
|
||||||
self._curve = _sn_to_elliptic_curve(backend, sn)
|
|
||||||
_mark_asn1_named_ec_curve(backend, ec_key_cdata)
|
|
||||||
_check_key_infinity(backend, ec_key_cdata)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def curve(self) -> ec.EllipticCurve:
|
|
||||||
return self._curve
|
|
||||||
|
|
||||||
@property
|
|
||||||
def key_size(self) -> int:
|
|
||||||
return self.curve.key_size
|
|
||||||
|
|
||||||
def public_numbers(self) -> ec.EllipticCurvePublicNumbers:
|
|
||||||
get_func, group = self._backend._ec_key_determine_group_get_func(
|
|
||||||
self._ec_key
|
|
||||||
)
|
|
||||||
point = self._backend._lib.EC_KEY_get0_public_key(self._ec_key)
|
|
||||||
self._backend.openssl_assert(point != self._backend._ffi.NULL)
|
|
||||||
|
|
||||||
with self._backend._tmp_bn_ctx() as bn_ctx:
|
|
||||||
bn_x = self._backend._lib.BN_CTX_get(bn_ctx)
|
|
||||||
bn_y = self._backend._lib.BN_CTX_get(bn_ctx)
|
|
||||||
|
|
||||||
res = get_func(group, point, bn_x, bn_y, bn_ctx)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
|
|
||||||
x = self._backend._bn_to_int(bn_x)
|
|
||||||
y = self._backend._bn_to_int(bn_y)
|
|
||||||
|
|
||||||
return ec.EllipticCurvePublicNumbers(x=x, y=y, curve=self._curve)
|
|
||||||
|
|
||||||
def _encode_point(self, format: serialization.PublicFormat) -> bytes:
|
|
||||||
if format is serialization.PublicFormat.CompressedPoint:
|
|
||||||
conversion = self._backend._lib.POINT_CONVERSION_COMPRESSED
|
|
||||||
else:
|
|
||||||
assert format is serialization.PublicFormat.UncompressedPoint
|
|
||||||
conversion = self._backend._lib.POINT_CONVERSION_UNCOMPRESSED
|
|
||||||
|
|
||||||
group = self._backend._lib.EC_KEY_get0_group(self._ec_key)
|
|
||||||
self._backend.openssl_assert(group != self._backend._ffi.NULL)
|
|
||||||
point = self._backend._lib.EC_KEY_get0_public_key(self._ec_key)
|
|
||||||
self._backend.openssl_assert(point != self._backend._ffi.NULL)
|
|
||||||
with self._backend._tmp_bn_ctx() as bn_ctx:
|
|
||||||
buflen = self._backend._lib.EC_POINT_point2oct(
|
|
||||||
group, point, conversion, self._backend._ffi.NULL, 0, bn_ctx
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(buflen > 0)
|
|
||||||
buf = self._backend._ffi.new("char[]", buflen)
|
|
||||||
res = self._backend._lib.EC_POINT_point2oct(
|
|
||||||
group, point, conversion, buf, buflen, bn_ctx
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(buflen == res)
|
|
||||||
|
|
||||||
return self._backend._ffi.buffer(buf)[:]
|
|
||||||
|
|
||||||
def public_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.PublicFormat,
|
|
||||||
) -> bytes:
|
|
||||||
if (
|
|
||||||
encoding is serialization.Encoding.X962
|
|
||||||
or format is serialization.PublicFormat.CompressedPoint
|
|
||||||
or format is serialization.PublicFormat.UncompressedPoint
|
|
||||||
):
|
|
||||||
if encoding is not serialization.Encoding.X962 or format not in (
|
|
||||||
serialization.PublicFormat.CompressedPoint,
|
|
||||||
serialization.PublicFormat.UncompressedPoint,
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"X962 encoding must be used with CompressedPoint or "
|
|
||||||
"UncompressedPoint format"
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._encode_point(format)
|
|
||||||
else:
|
|
||||||
return self._backend._public_key_bytes(
|
|
||||||
encoding, format, self, self._evp_pkey, None
|
|
||||||
)
|
|
||||||
|
|
||||||
def verify(
|
|
||||||
self,
|
|
||||||
signature: bytes,
|
|
||||||
data: bytes,
|
|
||||||
signature_algorithm: ec.EllipticCurveSignatureAlgorithm,
|
|
||||||
) -> None:
|
|
||||||
_check_signature_algorithm(signature_algorithm)
|
|
||||||
data, _ = _calculate_digest_and_algorithm(
|
|
||||||
data,
|
|
||||||
signature_algorithm.algorithm,
|
|
||||||
)
|
|
||||||
_ecdsa_sig_verify(self._backend, self, signature, data)
|
|
||||||
@@ -1,155 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography import exceptions
|
|
||||||
from cryptography.hazmat.primitives import serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
|
|
||||||
Ed25519PrivateKey,
|
|
||||||
Ed25519PublicKey,
|
|
||||||
_ED25519_KEY_SIZE,
|
|
||||||
_ED25519_SIG_SIZE,
|
|
||||||
)
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import Backend
|
|
||||||
|
|
||||||
|
|
||||||
class _Ed25519PublicKey(Ed25519PublicKey):
|
|
||||||
def __init__(self, backend: "Backend", evp_pkey):
|
|
||||||
self._backend = backend
|
|
||||||
self._evp_pkey = evp_pkey
|
|
||||||
|
|
||||||
def public_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.PublicFormat,
|
|
||||||
) -> bytes:
|
|
||||||
if (
|
|
||||||
encoding is serialization.Encoding.Raw
|
|
||||||
or format is serialization.PublicFormat.Raw
|
|
||||||
):
|
|
||||||
if (
|
|
||||||
encoding is not serialization.Encoding.Raw
|
|
||||||
or format is not serialization.PublicFormat.Raw
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"When using Raw both encoding and format must be Raw"
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._raw_public_bytes()
|
|
||||||
|
|
||||||
return self._backend._public_key_bytes(
|
|
||||||
encoding, format, self, self._evp_pkey, None
|
|
||||||
)
|
|
||||||
|
|
||||||
def _raw_public_bytes(self) -> bytes:
|
|
||||||
buf = self._backend._ffi.new("unsigned char []", _ED25519_KEY_SIZE)
|
|
||||||
buflen = self._backend._ffi.new("size_t *", _ED25519_KEY_SIZE)
|
|
||||||
res = self._backend._lib.EVP_PKEY_get_raw_public_key(
|
|
||||||
self._evp_pkey, buf, buflen
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
self._backend.openssl_assert(buflen[0] == _ED25519_KEY_SIZE)
|
|
||||||
return self._backend._ffi.buffer(buf, _ED25519_KEY_SIZE)[:]
|
|
||||||
|
|
||||||
def verify(self, signature: bytes, data: bytes) -> None:
|
|
||||||
evp_md_ctx = self._backend._lib.EVP_MD_CTX_new()
|
|
||||||
self._backend.openssl_assert(evp_md_ctx != self._backend._ffi.NULL)
|
|
||||||
evp_md_ctx = self._backend._ffi.gc(
|
|
||||||
evp_md_ctx, self._backend._lib.EVP_MD_CTX_free
|
|
||||||
)
|
|
||||||
res = self._backend._lib.EVP_DigestVerifyInit(
|
|
||||||
evp_md_ctx,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._evp_pkey,
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
res = self._backend._lib.EVP_DigestVerify(
|
|
||||||
evp_md_ctx, signature, len(signature), data, len(data)
|
|
||||||
)
|
|
||||||
if res != 1:
|
|
||||||
self._backend._consume_errors()
|
|
||||||
raise exceptions.InvalidSignature
|
|
||||||
|
|
||||||
|
|
||||||
class _Ed25519PrivateKey(Ed25519PrivateKey):
|
|
||||||
def __init__(self, backend: "Backend", evp_pkey):
|
|
||||||
self._backend = backend
|
|
||||||
self._evp_pkey = evp_pkey
|
|
||||||
|
|
||||||
def public_key(self) -> Ed25519PublicKey:
|
|
||||||
buf = self._backend._ffi.new("unsigned char []", _ED25519_KEY_SIZE)
|
|
||||||
buflen = self._backend._ffi.new("size_t *", _ED25519_KEY_SIZE)
|
|
||||||
res = self._backend._lib.EVP_PKEY_get_raw_public_key(
|
|
||||||
self._evp_pkey, buf, buflen
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
self._backend.openssl_assert(buflen[0] == _ED25519_KEY_SIZE)
|
|
||||||
public_bytes = self._backend._ffi.buffer(buf)[:]
|
|
||||||
return self._backend.ed25519_load_public_bytes(public_bytes)
|
|
||||||
|
|
||||||
def sign(self, data: bytes) -> bytes:
|
|
||||||
evp_md_ctx = self._backend._lib.EVP_MD_CTX_new()
|
|
||||||
self._backend.openssl_assert(evp_md_ctx != self._backend._ffi.NULL)
|
|
||||||
evp_md_ctx = self._backend._ffi.gc(
|
|
||||||
evp_md_ctx, self._backend._lib.EVP_MD_CTX_free
|
|
||||||
)
|
|
||||||
res = self._backend._lib.EVP_DigestSignInit(
|
|
||||||
evp_md_ctx,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._evp_pkey,
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
buf = self._backend._ffi.new("unsigned char[]", _ED25519_SIG_SIZE)
|
|
||||||
buflen = self._backend._ffi.new("size_t *", len(buf))
|
|
||||||
res = self._backend._lib.EVP_DigestSign(
|
|
||||||
evp_md_ctx, buf, buflen, data, len(data)
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
self._backend.openssl_assert(buflen[0] == _ED25519_SIG_SIZE)
|
|
||||||
return self._backend._ffi.buffer(buf, buflen[0])[:]
|
|
||||||
|
|
||||||
def private_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.PrivateFormat,
|
|
||||||
encryption_algorithm: serialization.KeySerializationEncryption,
|
|
||||||
) -> bytes:
|
|
||||||
if (
|
|
||||||
encoding is serialization.Encoding.Raw
|
|
||||||
or format is serialization.PublicFormat.Raw
|
|
||||||
):
|
|
||||||
if (
|
|
||||||
format is not serialization.PrivateFormat.Raw
|
|
||||||
or encoding is not serialization.Encoding.Raw
|
|
||||||
or not isinstance(
|
|
||||||
encryption_algorithm, serialization.NoEncryption
|
|
||||||
)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"When using Raw both encoding and format must be Raw "
|
|
||||||
"and encryption_algorithm must be NoEncryption()"
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._raw_private_bytes()
|
|
||||||
|
|
||||||
return self._backend._private_key_bytes(
|
|
||||||
encoding, format, encryption_algorithm, self, self._evp_pkey, None
|
|
||||||
)
|
|
||||||
|
|
||||||
def _raw_private_bytes(self) -> bytes:
|
|
||||||
buf = self._backend._ffi.new("unsigned char []", _ED25519_KEY_SIZE)
|
|
||||||
buflen = self._backend._ffi.new("size_t *", _ED25519_KEY_SIZE)
|
|
||||||
res = self._backend._lib.EVP_PKEY_get_raw_private_key(
|
|
||||||
self._evp_pkey, buf, buflen
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
self._backend.openssl_assert(buflen[0] == _ED25519_KEY_SIZE)
|
|
||||||
return self._backend._ffi.buffer(buf, _ED25519_KEY_SIZE)[:]
|
|
||||||
@@ -1,156 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography import exceptions
|
|
||||||
from cryptography.hazmat.primitives import serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric.ed448 import (
|
|
||||||
Ed448PrivateKey,
|
|
||||||
Ed448PublicKey,
|
|
||||||
)
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import Backend
|
|
||||||
|
|
||||||
_ED448_KEY_SIZE = 57
|
|
||||||
_ED448_SIG_SIZE = 114
|
|
||||||
|
|
||||||
|
|
||||||
class _Ed448PublicKey(Ed448PublicKey):
|
|
||||||
def __init__(self, backend: "Backend", evp_pkey):
|
|
||||||
self._backend = backend
|
|
||||||
self._evp_pkey = evp_pkey
|
|
||||||
|
|
||||||
def public_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.PublicFormat,
|
|
||||||
) -> bytes:
|
|
||||||
if (
|
|
||||||
encoding is serialization.Encoding.Raw
|
|
||||||
or format is serialization.PublicFormat.Raw
|
|
||||||
):
|
|
||||||
if (
|
|
||||||
encoding is not serialization.Encoding.Raw
|
|
||||||
or format is not serialization.PublicFormat.Raw
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"When using Raw both encoding and format must be Raw"
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._raw_public_bytes()
|
|
||||||
|
|
||||||
return self._backend._public_key_bytes(
|
|
||||||
encoding, format, self, self._evp_pkey, None
|
|
||||||
)
|
|
||||||
|
|
||||||
def _raw_public_bytes(self) -> bytes:
|
|
||||||
buf = self._backend._ffi.new("unsigned char []", _ED448_KEY_SIZE)
|
|
||||||
buflen = self._backend._ffi.new("size_t *", _ED448_KEY_SIZE)
|
|
||||||
res = self._backend._lib.EVP_PKEY_get_raw_public_key(
|
|
||||||
self._evp_pkey, buf, buflen
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
self._backend.openssl_assert(buflen[0] == _ED448_KEY_SIZE)
|
|
||||||
return self._backend._ffi.buffer(buf, _ED448_KEY_SIZE)[:]
|
|
||||||
|
|
||||||
def verify(self, signature: bytes, data: bytes) -> None:
|
|
||||||
evp_md_ctx = self._backend._lib.EVP_MD_CTX_new()
|
|
||||||
self._backend.openssl_assert(evp_md_ctx != self._backend._ffi.NULL)
|
|
||||||
evp_md_ctx = self._backend._ffi.gc(
|
|
||||||
evp_md_ctx, self._backend._lib.EVP_MD_CTX_free
|
|
||||||
)
|
|
||||||
res = self._backend._lib.EVP_DigestVerifyInit(
|
|
||||||
evp_md_ctx,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._evp_pkey,
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
res = self._backend._lib.EVP_DigestVerify(
|
|
||||||
evp_md_ctx, signature, len(signature), data, len(data)
|
|
||||||
)
|
|
||||||
if res != 1:
|
|
||||||
self._backend._consume_errors()
|
|
||||||
raise exceptions.InvalidSignature
|
|
||||||
|
|
||||||
|
|
||||||
class _Ed448PrivateKey(Ed448PrivateKey):
|
|
||||||
def __init__(self, backend: "Backend", evp_pkey):
|
|
||||||
self._backend = backend
|
|
||||||
self._evp_pkey = evp_pkey
|
|
||||||
|
|
||||||
def public_key(self) -> Ed448PublicKey:
|
|
||||||
buf = self._backend._ffi.new("unsigned char []", _ED448_KEY_SIZE)
|
|
||||||
buflen = self._backend._ffi.new("size_t *", _ED448_KEY_SIZE)
|
|
||||||
res = self._backend._lib.EVP_PKEY_get_raw_public_key(
|
|
||||||
self._evp_pkey, buf, buflen
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
self._backend.openssl_assert(buflen[0] == _ED448_KEY_SIZE)
|
|
||||||
public_bytes = self._backend._ffi.buffer(buf)[:]
|
|
||||||
return self._backend.ed448_load_public_bytes(public_bytes)
|
|
||||||
|
|
||||||
def sign(self, data: bytes) -> bytes:
|
|
||||||
evp_md_ctx = self._backend._lib.EVP_MD_CTX_new()
|
|
||||||
self._backend.openssl_assert(evp_md_ctx != self._backend._ffi.NULL)
|
|
||||||
evp_md_ctx = self._backend._ffi.gc(
|
|
||||||
evp_md_ctx, self._backend._lib.EVP_MD_CTX_free
|
|
||||||
)
|
|
||||||
res = self._backend._lib.EVP_DigestSignInit(
|
|
||||||
evp_md_ctx,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._evp_pkey,
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
buf = self._backend._ffi.new("unsigned char[]", _ED448_SIG_SIZE)
|
|
||||||
buflen = self._backend._ffi.new("size_t *", len(buf))
|
|
||||||
res = self._backend._lib.EVP_DigestSign(
|
|
||||||
evp_md_ctx, buf, buflen, data, len(data)
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
self._backend.openssl_assert(buflen[0] == _ED448_SIG_SIZE)
|
|
||||||
return self._backend._ffi.buffer(buf, buflen[0])[:]
|
|
||||||
|
|
||||||
def private_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.PrivateFormat,
|
|
||||||
encryption_algorithm: serialization.KeySerializationEncryption,
|
|
||||||
) -> bytes:
|
|
||||||
if (
|
|
||||||
encoding is serialization.Encoding.Raw
|
|
||||||
or format is serialization.PublicFormat.Raw
|
|
||||||
):
|
|
||||||
if (
|
|
||||||
format is not serialization.PrivateFormat.Raw
|
|
||||||
or encoding is not serialization.Encoding.Raw
|
|
||||||
or not isinstance(
|
|
||||||
encryption_algorithm, serialization.NoEncryption
|
|
||||||
)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"When using Raw both encoding and format must be Raw "
|
|
||||||
"and encryption_algorithm must be NoEncryption()"
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._raw_private_bytes()
|
|
||||||
|
|
||||||
return self._backend._private_key_bytes(
|
|
||||||
encoding, format, encryption_algorithm, self, self._evp_pkey, None
|
|
||||||
)
|
|
||||||
|
|
||||||
def _raw_private_bytes(self) -> bytes:
|
|
||||||
buf = self._backend._ffi.new("unsigned char []", _ED448_KEY_SIZE)
|
|
||||||
buflen = self._backend._ffi.new("size_t *", _ED448_KEY_SIZE)
|
|
||||||
res = self._backend._lib.EVP_PKEY_get_raw_private_key(
|
|
||||||
self._evp_pkey, buf, buflen
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
self._backend.openssl_assert(buflen[0] == _ED448_KEY_SIZE)
|
|
||||||
return self._backend._ffi.buffer(buf, _ED448_KEY_SIZE)[:]
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
from cryptography import x509
|
|
||||||
|
|
||||||
|
|
||||||
_CRLREASONFLAGS = {
|
|
||||||
x509.ReasonFlags.key_compromise: 1,
|
|
||||||
x509.ReasonFlags.ca_compromise: 2,
|
|
||||||
x509.ReasonFlags.affiliation_changed: 3,
|
|
||||||
x509.ReasonFlags.superseded: 4,
|
|
||||||
x509.ReasonFlags.cessation_of_operation: 5,
|
|
||||||
x509.ReasonFlags.certificate_hold: 6,
|
|
||||||
x509.ReasonFlags.privilege_withdrawn: 7,
|
|
||||||
x509.ReasonFlags.aa_compromise: 8,
|
|
||||||
}
|
|
||||||
@@ -1,87 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
|
|
||||||
from cryptography.hazmat.primitives import hashes
|
|
||||||
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import Backend
|
|
||||||
|
|
||||||
|
|
||||||
class _HashContext(hashes.HashContext):
|
|
||||||
def __init__(
|
|
||||||
self, backend: "Backend", algorithm: hashes.HashAlgorithm, ctx=None
|
|
||||||
) -> None:
|
|
||||||
self._algorithm = algorithm
|
|
||||||
|
|
||||||
self._backend = backend
|
|
||||||
|
|
||||||
if ctx is None:
|
|
||||||
ctx = self._backend._lib.EVP_MD_CTX_new()
|
|
||||||
ctx = self._backend._ffi.gc(
|
|
||||||
ctx, self._backend._lib.EVP_MD_CTX_free
|
|
||||||
)
|
|
||||||
evp_md = self._backend._evp_md_from_algorithm(algorithm)
|
|
||||||
if evp_md == self._backend._ffi.NULL:
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"{} is not a supported hash on this backend.".format(
|
|
||||||
algorithm.name
|
|
||||||
),
|
|
||||||
_Reasons.UNSUPPORTED_HASH,
|
|
||||||
)
|
|
||||||
res = self._backend._lib.EVP_DigestInit_ex(
|
|
||||||
ctx, evp_md, self._backend._ffi.NULL
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
|
|
||||||
self._ctx = ctx
|
|
||||||
|
|
||||||
@property
|
|
||||||
def algorithm(self) -> hashes.HashAlgorithm:
|
|
||||||
return self._algorithm
|
|
||||||
|
|
||||||
def copy(self) -> "_HashContext":
|
|
||||||
copied_ctx = self._backend._lib.EVP_MD_CTX_new()
|
|
||||||
copied_ctx = self._backend._ffi.gc(
|
|
||||||
copied_ctx, self._backend._lib.EVP_MD_CTX_free
|
|
||||||
)
|
|
||||||
res = self._backend._lib.EVP_MD_CTX_copy_ex(copied_ctx, self._ctx)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
return _HashContext(self._backend, self.algorithm, ctx=copied_ctx)
|
|
||||||
|
|
||||||
def update(self, data: bytes) -> None:
|
|
||||||
data_ptr = self._backend._ffi.from_buffer(data)
|
|
||||||
res = self._backend._lib.EVP_DigestUpdate(
|
|
||||||
self._ctx, data_ptr, len(data)
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
|
|
||||||
def finalize(self) -> bytes:
|
|
||||||
if isinstance(self.algorithm, hashes.ExtendableOutputFunction):
|
|
||||||
# extendable output functions use a different finalize
|
|
||||||
return self._finalize_xof()
|
|
||||||
else:
|
|
||||||
buf = self._backend._ffi.new(
|
|
||||||
"unsigned char[]", self._backend._lib.EVP_MAX_MD_SIZE
|
|
||||||
)
|
|
||||||
outlen = self._backend._ffi.new("unsigned int *")
|
|
||||||
res = self._backend._lib.EVP_DigestFinal_ex(self._ctx, buf, outlen)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
self._backend.openssl_assert(
|
|
||||||
outlen[0] == self.algorithm.digest_size
|
|
||||||
)
|
|
||||||
return self._backend._ffi.buffer(buf)[: outlen[0]]
|
|
||||||
|
|
||||||
def _finalize_xof(self) -> bytes:
|
|
||||||
buf = self._backend._ffi.new(
|
|
||||||
"unsigned char[]", self.algorithm.digest_size
|
|
||||||
)
|
|
||||||
res = self._backend._lib.EVP_DigestFinalXOF(
|
|
||||||
self._ctx, buf, self.algorithm.digest_size
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
return self._backend._ffi.buffer(buf)[: self.algorithm.digest_size]
|
|
||||||
@@ -1,85 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.exceptions import (
|
|
||||||
InvalidSignature,
|
|
||||||
UnsupportedAlgorithm,
|
|
||||||
_Reasons,
|
|
||||||
)
|
|
||||||
from cryptography.hazmat.primitives import constant_time, hashes
|
|
||||||
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import Backend
|
|
||||||
|
|
||||||
|
|
||||||
class _HMACContext(hashes.HashContext):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
backend: "Backend",
|
|
||||||
key: bytes,
|
|
||||||
algorithm: hashes.HashAlgorithm,
|
|
||||||
ctx=None,
|
|
||||||
):
|
|
||||||
self._algorithm = algorithm
|
|
||||||
self._backend = backend
|
|
||||||
|
|
||||||
if ctx is None:
|
|
||||||
ctx = self._backend._lib.HMAC_CTX_new()
|
|
||||||
self._backend.openssl_assert(ctx != self._backend._ffi.NULL)
|
|
||||||
ctx = self._backend._ffi.gc(ctx, self._backend._lib.HMAC_CTX_free)
|
|
||||||
evp_md = self._backend._evp_md_from_algorithm(algorithm)
|
|
||||||
if evp_md == self._backend._ffi.NULL:
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"{} is not a supported hash on this backend".format(
|
|
||||||
algorithm.name
|
|
||||||
),
|
|
||||||
_Reasons.UNSUPPORTED_HASH,
|
|
||||||
)
|
|
||||||
key_ptr = self._backend._ffi.from_buffer(key)
|
|
||||||
res = self._backend._lib.HMAC_Init_ex(
|
|
||||||
ctx, key_ptr, len(key), evp_md, self._backend._ffi.NULL
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
|
|
||||||
self._ctx = ctx
|
|
||||||
self._key = key
|
|
||||||
|
|
||||||
@property
|
|
||||||
def algorithm(self) -> hashes.HashAlgorithm:
|
|
||||||
return self._algorithm
|
|
||||||
|
|
||||||
def copy(self) -> "_HMACContext":
|
|
||||||
copied_ctx = self._backend._lib.HMAC_CTX_new()
|
|
||||||
self._backend.openssl_assert(copied_ctx != self._backend._ffi.NULL)
|
|
||||||
copied_ctx = self._backend._ffi.gc(
|
|
||||||
copied_ctx, self._backend._lib.HMAC_CTX_free
|
|
||||||
)
|
|
||||||
res = self._backend._lib.HMAC_CTX_copy(copied_ctx, self._ctx)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
return _HMACContext(
|
|
||||||
self._backend, self._key, self.algorithm, ctx=copied_ctx
|
|
||||||
)
|
|
||||||
|
|
||||||
def update(self, data: bytes) -> None:
|
|
||||||
data_ptr = self._backend._ffi.from_buffer(data)
|
|
||||||
res = self._backend._lib.HMAC_Update(self._ctx, data_ptr, len(data))
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
|
|
||||||
def finalize(self) -> bytes:
|
|
||||||
buf = self._backend._ffi.new(
|
|
||||||
"unsigned char[]", self._backend._lib.EVP_MAX_MD_SIZE
|
|
||||||
)
|
|
||||||
outlen = self._backend._ffi.new("unsigned int *")
|
|
||||||
res = self._backend._lib.HMAC_Final(self._ctx, buf, outlen)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
self._backend.openssl_assert(outlen[0] == self.algorithm.digest_size)
|
|
||||||
return self._backend._ffi.buffer(buf)[: outlen[0]]
|
|
||||||
|
|
||||||
def verify(self, signature: bytes) -> None:
|
|
||||||
digest = self.finalize()
|
|
||||||
if not constant_time.bytes_eq(digest, signature):
|
|
||||||
raise InvalidSignature("Signature did not match digest.")
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.exceptions import InvalidSignature
|
|
||||||
from cryptography.hazmat.primitives import constant_time
|
|
||||||
|
|
||||||
|
|
||||||
_POLY1305_TAG_SIZE = 16
|
|
||||||
_POLY1305_KEY_SIZE = 32
|
|
||||||
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import Backend
|
|
||||||
|
|
||||||
|
|
||||||
class _Poly1305Context:
|
|
||||||
def __init__(self, backend: "Backend", key: bytes) -> None:
|
|
||||||
self._backend = backend
|
|
||||||
|
|
||||||
key_ptr = self._backend._ffi.from_buffer(key)
|
|
||||||
# This function copies the key into OpenSSL-owned memory so we don't
|
|
||||||
# need to retain it ourselves
|
|
||||||
evp_pkey = self._backend._lib.EVP_PKEY_new_raw_private_key(
|
|
||||||
self._backend._lib.NID_poly1305,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
key_ptr,
|
|
||||||
len(key),
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(evp_pkey != self._backend._ffi.NULL)
|
|
||||||
self._evp_pkey = self._backend._ffi.gc(
|
|
||||||
evp_pkey, self._backend._lib.EVP_PKEY_free
|
|
||||||
)
|
|
||||||
ctx = self._backend._lib.EVP_MD_CTX_new()
|
|
||||||
self._backend.openssl_assert(ctx != self._backend._ffi.NULL)
|
|
||||||
self._ctx = self._backend._ffi.gc(
|
|
||||||
ctx, self._backend._lib.EVP_MD_CTX_free
|
|
||||||
)
|
|
||||||
res = self._backend._lib.EVP_DigestSignInit(
|
|
||||||
self._ctx,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._evp_pkey,
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
|
|
||||||
def update(self, data: bytes) -> None:
|
|
||||||
data_ptr = self._backend._ffi.from_buffer(data)
|
|
||||||
res = self._backend._lib.EVP_DigestSignUpdate(
|
|
||||||
self._ctx, data_ptr, len(data)
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
|
|
||||||
def finalize(self) -> bytes:
|
|
||||||
buf = self._backend._ffi.new("unsigned char[]", _POLY1305_TAG_SIZE)
|
|
||||||
outlen = self._backend._ffi.new("size_t *", _POLY1305_TAG_SIZE)
|
|
||||||
res = self._backend._lib.EVP_DigestSignFinal(self._ctx, buf, outlen)
|
|
||||||
self._backend.openssl_assert(res != 0)
|
|
||||||
self._backend.openssl_assert(outlen[0] == _POLY1305_TAG_SIZE)
|
|
||||||
return self._backend._ffi.buffer(buf)[: outlen[0]]
|
|
||||||
|
|
||||||
def verify(self, tag: bytes) -> None:
|
|
||||||
mac = self.finalize()
|
|
||||||
if not constant_time.bytes_eq(mac, tag):
|
|
||||||
raise InvalidSignature("Value did not match computed tag.")
|
|
||||||
@@ -1,567 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.exceptions import (
|
|
||||||
InvalidSignature,
|
|
||||||
UnsupportedAlgorithm,
|
|
||||||
_Reasons,
|
|
||||||
)
|
|
||||||
from cryptography.hazmat.backends.openssl.utils import (
|
|
||||||
_calculate_digest_and_algorithm,
|
|
||||||
)
|
|
||||||
from cryptography.hazmat.primitives import hashes, serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import (
|
|
||||||
utils as asym_utils,
|
|
||||||
)
|
|
||||||
from cryptography.hazmat.primitives.asymmetric.padding import (
|
|
||||||
AsymmetricPadding,
|
|
||||||
MGF1,
|
|
||||||
OAEP,
|
|
||||||
PKCS1v15,
|
|
||||||
PSS,
|
|
||||||
_Auto,
|
|
||||||
_DigestLength,
|
|
||||||
_MaxLength,
|
|
||||||
calculate_max_pss_salt_length,
|
|
||||||
)
|
|
||||||
from cryptography.hazmat.primitives.asymmetric.rsa import (
|
|
||||||
RSAPrivateKey,
|
|
||||||
RSAPrivateNumbers,
|
|
||||||
RSAPublicKey,
|
|
||||||
RSAPublicNumbers,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import Backend
|
|
||||||
|
|
||||||
|
|
||||||
def _get_rsa_pss_salt_length(
|
|
||||||
backend: "Backend",
|
|
||||||
pss: PSS,
|
|
||||||
key: typing.Union[RSAPrivateKey, RSAPublicKey],
|
|
||||||
hash_algorithm: hashes.HashAlgorithm,
|
|
||||||
) -> int:
|
|
||||||
salt = pss._salt_length
|
|
||||||
|
|
||||||
if isinstance(salt, _MaxLength):
|
|
||||||
return calculate_max_pss_salt_length(key, hash_algorithm)
|
|
||||||
elif isinstance(salt, _DigestLength):
|
|
||||||
return hash_algorithm.digest_size
|
|
||||||
elif isinstance(salt, _Auto):
|
|
||||||
if isinstance(key, RSAPrivateKey):
|
|
||||||
raise ValueError(
|
|
||||||
"PSS salt length can only be set to AUTO when verifying"
|
|
||||||
)
|
|
||||||
return backend._lib.RSA_PSS_SALTLEN_AUTO
|
|
||||||
else:
|
|
||||||
return salt
|
|
||||||
|
|
||||||
|
|
||||||
def _enc_dec_rsa(
|
|
||||||
backend: "Backend",
|
|
||||||
key: typing.Union["_RSAPrivateKey", "_RSAPublicKey"],
|
|
||||||
data: bytes,
|
|
||||||
padding: AsymmetricPadding,
|
|
||||||
) -> bytes:
|
|
||||||
if not isinstance(padding, AsymmetricPadding):
|
|
||||||
raise TypeError("Padding must be an instance of AsymmetricPadding.")
|
|
||||||
|
|
||||||
if isinstance(padding, PKCS1v15):
|
|
||||||
padding_enum = backend._lib.RSA_PKCS1_PADDING
|
|
||||||
elif isinstance(padding, OAEP):
|
|
||||||
padding_enum = backend._lib.RSA_PKCS1_OAEP_PADDING
|
|
||||||
|
|
||||||
if not isinstance(padding._mgf, MGF1):
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"Only MGF1 is supported by this backend.",
|
|
||||||
_Reasons.UNSUPPORTED_MGF,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not backend.rsa_padding_supported(padding):
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"This combination of padding and hash algorithm is not "
|
|
||||||
"supported by this backend.",
|
|
||||||
_Reasons.UNSUPPORTED_PADDING,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"{} is not supported by this backend.".format(padding.name),
|
|
||||||
_Reasons.UNSUPPORTED_PADDING,
|
|
||||||
)
|
|
||||||
|
|
||||||
return _enc_dec_rsa_pkey_ctx(backend, key, data, padding_enum, padding)
|
|
||||||
|
|
||||||
|
|
||||||
def _enc_dec_rsa_pkey_ctx(
|
|
||||||
backend: "Backend",
|
|
||||||
key: typing.Union["_RSAPrivateKey", "_RSAPublicKey"],
|
|
||||||
data: bytes,
|
|
||||||
padding_enum: int,
|
|
||||||
padding: AsymmetricPadding,
|
|
||||||
) -> bytes:
|
|
||||||
init: typing.Callable[[typing.Any], int]
|
|
||||||
crypt: typing.Callable[[typing.Any, typing.Any, int, bytes, int], int]
|
|
||||||
if isinstance(key, _RSAPublicKey):
|
|
||||||
init = backend._lib.EVP_PKEY_encrypt_init
|
|
||||||
crypt = backend._lib.EVP_PKEY_encrypt
|
|
||||||
else:
|
|
||||||
init = backend._lib.EVP_PKEY_decrypt_init
|
|
||||||
crypt = backend._lib.EVP_PKEY_decrypt
|
|
||||||
|
|
||||||
pkey_ctx = backend._lib.EVP_PKEY_CTX_new(key._evp_pkey, backend._ffi.NULL)
|
|
||||||
backend.openssl_assert(pkey_ctx != backend._ffi.NULL)
|
|
||||||
pkey_ctx = backend._ffi.gc(pkey_ctx, backend._lib.EVP_PKEY_CTX_free)
|
|
||||||
res = init(pkey_ctx)
|
|
||||||
backend.openssl_assert(res == 1)
|
|
||||||
res = backend._lib.EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, padding_enum)
|
|
||||||
backend.openssl_assert(res > 0)
|
|
||||||
buf_size = backend._lib.EVP_PKEY_size(key._evp_pkey)
|
|
||||||
backend.openssl_assert(buf_size > 0)
|
|
||||||
if isinstance(padding, OAEP):
|
|
||||||
mgf1_md = backend._evp_md_non_null_from_algorithm(
|
|
||||||
padding._mgf._algorithm
|
|
||||||
)
|
|
||||||
res = backend._lib.EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, mgf1_md)
|
|
||||||
backend.openssl_assert(res > 0)
|
|
||||||
oaep_md = backend._evp_md_non_null_from_algorithm(padding._algorithm)
|
|
||||||
res = backend._lib.EVP_PKEY_CTX_set_rsa_oaep_md(pkey_ctx, oaep_md)
|
|
||||||
backend.openssl_assert(res > 0)
|
|
||||||
|
|
||||||
if (
|
|
||||||
isinstance(padding, OAEP)
|
|
||||||
and padding._label is not None
|
|
||||||
and len(padding._label) > 0
|
|
||||||
):
|
|
||||||
# set0_rsa_oaep_label takes ownership of the char * so we need to
|
|
||||||
# copy it into some new memory
|
|
||||||
labelptr = backend._lib.OPENSSL_malloc(len(padding._label))
|
|
||||||
backend.openssl_assert(labelptr != backend._ffi.NULL)
|
|
||||||
backend._ffi.memmove(labelptr, padding._label, len(padding._label))
|
|
||||||
res = backend._lib.EVP_PKEY_CTX_set0_rsa_oaep_label(
|
|
||||||
pkey_ctx, labelptr, len(padding._label)
|
|
||||||
)
|
|
||||||
backend.openssl_assert(res == 1)
|
|
||||||
|
|
||||||
outlen = backend._ffi.new("size_t *", buf_size)
|
|
||||||
buf = backend._ffi.new("unsigned char[]", buf_size)
|
|
||||||
# Everything from this line onwards is written with the goal of being as
|
|
||||||
# constant-time as is practical given the constraints of Python and our
|
|
||||||
# API. See Bleichenbacher's '98 attack on RSA, and its many many variants.
|
|
||||||
# As such, you should not attempt to change this (particularly to "clean it
|
|
||||||
# up") without understanding why it was written this way (see
|
|
||||||
# Chesterton's Fence), and without measuring to verify you have not
|
|
||||||
# introduced observable time differences.
|
|
||||||
res = crypt(pkey_ctx, buf, outlen, data, len(data))
|
|
||||||
resbuf = backend._ffi.buffer(buf)[: outlen[0]]
|
|
||||||
backend._lib.ERR_clear_error()
|
|
||||||
if res <= 0:
|
|
||||||
raise ValueError("Encryption/decryption failed.")
|
|
||||||
return resbuf
|
|
||||||
|
|
||||||
|
|
||||||
def _rsa_sig_determine_padding(
|
|
||||||
backend: "Backend",
|
|
||||||
key: typing.Union["_RSAPrivateKey", "_RSAPublicKey"],
|
|
||||||
padding: AsymmetricPadding,
|
|
||||||
algorithm: typing.Optional[hashes.HashAlgorithm],
|
|
||||||
) -> int:
|
|
||||||
if not isinstance(padding, AsymmetricPadding):
|
|
||||||
raise TypeError("Expected provider of AsymmetricPadding.")
|
|
||||||
|
|
||||||
pkey_size = backend._lib.EVP_PKEY_size(key._evp_pkey)
|
|
||||||
backend.openssl_assert(pkey_size > 0)
|
|
||||||
|
|
||||||
if isinstance(padding, PKCS1v15):
|
|
||||||
# Hash algorithm is ignored for PKCS1v15-padding, may be None.
|
|
||||||
padding_enum = backend._lib.RSA_PKCS1_PADDING
|
|
||||||
elif isinstance(padding, PSS):
|
|
||||||
if not isinstance(padding._mgf, MGF1):
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"Only MGF1 is supported by this backend.",
|
|
||||||
_Reasons.UNSUPPORTED_MGF,
|
|
||||||
)
|
|
||||||
|
|
||||||
# PSS padding requires a hash algorithm
|
|
||||||
if not isinstance(algorithm, hashes.HashAlgorithm):
|
|
||||||
raise TypeError("Expected instance of hashes.HashAlgorithm.")
|
|
||||||
|
|
||||||
# Size of key in bytes - 2 is the maximum
|
|
||||||
# PSS signature length (salt length is checked later)
|
|
||||||
if pkey_size - algorithm.digest_size - 2 < 0:
|
|
||||||
raise ValueError(
|
|
||||||
"Digest too large for key size. Use a larger "
|
|
||||||
"key or different digest."
|
|
||||||
)
|
|
||||||
|
|
||||||
padding_enum = backend._lib.RSA_PKCS1_PSS_PADDING
|
|
||||||
else:
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"{} is not supported by this backend.".format(padding.name),
|
|
||||||
_Reasons.UNSUPPORTED_PADDING,
|
|
||||||
)
|
|
||||||
|
|
||||||
return padding_enum
|
|
||||||
|
|
||||||
|
|
||||||
# Hash algorithm can be absent (None) to initialize the context without setting
|
|
||||||
# any message digest algorithm. This is currently only valid for the PKCS1v15
|
|
||||||
# padding type, where it means that the signature data is encoded/decoded
|
|
||||||
# as provided, without being wrapped in a DigestInfo structure.
|
|
||||||
def _rsa_sig_setup(
|
|
||||||
backend: "Backend",
|
|
||||||
padding: AsymmetricPadding,
|
|
||||||
algorithm: typing.Optional[hashes.HashAlgorithm],
|
|
||||||
key: typing.Union["_RSAPublicKey", "_RSAPrivateKey"],
|
|
||||||
init_func: typing.Callable[[typing.Any], int],
|
|
||||||
):
|
|
||||||
padding_enum = _rsa_sig_determine_padding(backend, key, padding, algorithm)
|
|
||||||
pkey_ctx = backend._lib.EVP_PKEY_CTX_new(key._evp_pkey, backend._ffi.NULL)
|
|
||||||
backend.openssl_assert(pkey_ctx != backend._ffi.NULL)
|
|
||||||
pkey_ctx = backend._ffi.gc(pkey_ctx, backend._lib.EVP_PKEY_CTX_free)
|
|
||||||
res = init_func(pkey_ctx)
|
|
||||||
if res != 1:
|
|
||||||
errors = backend._consume_errors()
|
|
||||||
raise ValueError("Unable to sign/verify with this key", errors)
|
|
||||||
|
|
||||||
if algorithm is not None:
|
|
||||||
evp_md = backend._evp_md_non_null_from_algorithm(algorithm)
|
|
||||||
res = backend._lib.EVP_PKEY_CTX_set_signature_md(pkey_ctx, evp_md)
|
|
||||||
if res <= 0:
|
|
||||||
backend._consume_errors()
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"{} is not supported by this backend for RSA signing.".format(
|
|
||||||
algorithm.name
|
|
||||||
),
|
|
||||||
_Reasons.UNSUPPORTED_HASH,
|
|
||||||
)
|
|
||||||
res = backend._lib.EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, padding_enum)
|
|
||||||
if res <= 0:
|
|
||||||
backend._consume_errors()
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"{} is not supported for the RSA signature operation.".format(
|
|
||||||
padding.name
|
|
||||||
),
|
|
||||||
_Reasons.UNSUPPORTED_PADDING,
|
|
||||||
)
|
|
||||||
if isinstance(padding, PSS):
|
|
||||||
assert isinstance(algorithm, hashes.HashAlgorithm)
|
|
||||||
res = backend._lib.EVP_PKEY_CTX_set_rsa_pss_saltlen(
|
|
||||||
pkey_ctx,
|
|
||||||
_get_rsa_pss_salt_length(backend, padding, key, algorithm),
|
|
||||||
)
|
|
||||||
backend.openssl_assert(res > 0)
|
|
||||||
|
|
||||||
mgf1_md = backend._evp_md_non_null_from_algorithm(
|
|
||||||
padding._mgf._algorithm
|
|
||||||
)
|
|
||||||
res = backend._lib.EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, mgf1_md)
|
|
||||||
backend.openssl_assert(res > 0)
|
|
||||||
|
|
||||||
return pkey_ctx
|
|
||||||
|
|
||||||
|
|
||||||
def _rsa_sig_sign(
|
|
||||||
backend: "Backend",
|
|
||||||
padding: AsymmetricPadding,
|
|
||||||
algorithm: hashes.HashAlgorithm,
|
|
||||||
private_key: "_RSAPrivateKey",
|
|
||||||
data: bytes,
|
|
||||||
) -> bytes:
|
|
||||||
pkey_ctx = _rsa_sig_setup(
|
|
||||||
backend,
|
|
||||||
padding,
|
|
||||||
algorithm,
|
|
||||||
private_key,
|
|
||||||
backend._lib.EVP_PKEY_sign_init,
|
|
||||||
)
|
|
||||||
buflen = backend._ffi.new("size_t *")
|
|
||||||
res = backend._lib.EVP_PKEY_sign(
|
|
||||||
pkey_ctx, backend._ffi.NULL, buflen, data, len(data)
|
|
||||||
)
|
|
||||||
backend.openssl_assert(res == 1)
|
|
||||||
buf = backend._ffi.new("unsigned char[]", buflen[0])
|
|
||||||
res = backend._lib.EVP_PKEY_sign(pkey_ctx, buf, buflen, data, len(data))
|
|
||||||
if res != 1:
|
|
||||||
errors = backend._consume_errors_with_text()
|
|
||||||
raise ValueError(
|
|
||||||
"Digest or salt length too long for key size. Use a larger key "
|
|
||||||
"or shorter salt length if you are specifying a PSS salt",
|
|
||||||
errors,
|
|
||||||
)
|
|
||||||
|
|
||||||
return backend._ffi.buffer(buf)[:]
|
|
||||||
|
|
||||||
|
|
||||||
def _rsa_sig_verify(
|
|
||||||
backend: "Backend",
|
|
||||||
padding: AsymmetricPadding,
|
|
||||||
algorithm: hashes.HashAlgorithm,
|
|
||||||
public_key: "_RSAPublicKey",
|
|
||||||
signature: bytes,
|
|
||||||
data: bytes,
|
|
||||||
) -> None:
|
|
||||||
pkey_ctx = _rsa_sig_setup(
|
|
||||||
backend,
|
|
||||||
padding,
|
|
||||||
algorithm,
|
|
||||||
public_key,
|
|
||||||
backend._lib.EVP_PKEY_verify_init,
|
|
||||||
)
|
|
||||||
res = backend._lib.EVP_PKEY_verify(
|
|
||||||
pkey_ctx, signature, len(signature), data, len(data)
|
|
||||||
)
|
|
||||||
# The previous call can return negative numbers in the event of an
|
|
||||||
# error. This is not a signature failure but we need to fail if it
|
|
||||||
# occurs.
|
|
||||||
backend.openssl_assert(res >= 0)
|
|
||||||
if res == 0:
|
|
||||||
backend._consume_errors()
|
|
||||||
raise InvalidSignature
|
|
||||||
|
|
||||||
|
|
||||||
def _rsa_sig_recover(
|
|
||||||
backend: "Backend",
|
|
||||||
padding: AsymmetricPadding,
|
|
||||||
algorithm: typing.Optional[hashes.HashAlgorithm],
|
|
||||||
public_key: "_RSAPublicKey",
|
|
||||||
signature: bytes,
|
|
||||||
) -> bytes:
|
|
||||||
pkey_ctx = _rsa_sig_setup(
|
|
||||||
backend,
|
|
||||||
padding,
|
|
||||||
algorithm,
|
|
||||||
public_key,
|
|
||||||
backend._lib.EVP_PKEY_verify_recover_init,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Attempt to keep the rest of the code in this function as constant/time
|
|
||||||
# as possible. See the comment in _enc_dec_rsa_pkey_ctx. Note that the
|
|
||||||
# buflen parameter is used even though its value may be undefined in the
|
|
||||||
# error case. Due to the tolerant nature of Python slicing this does not
|
|
||||||
# trigger any exceptions.
|
|
||||||
maxlen = backend._lib.EVP_PKEY_size(public_key._evp_pkey)
|
|
||||||
backend.openssl_assert(maxlen > 0)
|
|
||||||
buf = backend._ffi.new("unsigned char[]", maxlen)
|
|
||||||
buflen = backend._ffi.new("size_t *", maxlen)
|
|
||||||
res = backend._lib.EVP_PKEY_verify_recover(
|
|
||||||
pkey_ctx, buf, buflen, signature, len(signature)
|
|
||||||
)
|
|
||||||
resbuf = backend._ffi.buffer(buf)[: buflen[0]]
|
|
||||||
backend._lib.ERR_clear_error()
|
|
||||||
# Assume that all parameter errors are handled during the setup phase and
|
|
||||||
# any error here is due to invalid signature.
|
|
||||||
if res != 1:
|
|
||||||
raise InvalidSignature
|
|
||||||
return resbuf
|
|
||||||
|
|
||||||
|
|
||||||
class _RSAPrivateKey(RSAPrivateKey):
|
|
||||||
_evp_pkey: object
|
|
||||||
_rsa_cdata: object
|
|
||||||
_key_size: int
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, backend: "Backend", rsa_cdata, evp_pkey, _skip_check_key: bool
|
|
||||||
):
|
|
||||||
res: int
|
|
||||||
# RSA_check_key is slower in OpenSSL 3.0.0 due to improved
|
|
||||||
# primality checking. In normal use this is unlikely to be a problem
|
|
||||||
# since users don't load new keys constantly, but for TESTING we've
|
|
||||||
# added an init arg that allows skipping the checks. You should not
|
|
||||||
# use this in production code unless you understand the consequences.
|
|
||||||
if not _skip_check_key:
|
|
||||||
res = backend._lib.RSA_check_key(rsa_cdata)
|
|
||||||
if res != 1:
|
|
||||||
errors = backend._consume_errors_with_text()
|
|
||||||
raise ValueError("Invalid private key", errors)
|
|
||||||
# 2 is prime and passes an RSA key check, so we also check
|
|
||||||
# if p and q are odd just to be safe.
|
|
||||||
p = backend._ffi.new("BIGNUM **")
|
|
||||||
q = backend._ffi.new("BIGNUM **")
|
|
||||||
backend._lib.RSA_get0_factors(rsa_cdata, p, q)
|
|
||||||
backend.openssl_assert(p[0] != backend._ffi.NULL)
|
|
||||||
backend.openssl_assert(q[0] != backend._ffi.NULL)
|
|
||||||
p_odd = backend._lib.BN_is_odd(p[0])
|
|
||||||
q_odd = backend._lib.BN_is_odd(q[0])
|
|
||||||
if p_odd != 1 or q_odd != 1:
|
|
||||||
errors = backend._consume_errors_with_text()
|
|
||||||
raise ValueError("Invalid private key", errors)
|
|
||||||
|
|
||||||
# Blinding is on by default in many versions of OpenSSL, but let's
|
|
||||||
# just be conservative here.
|
|
||||||
res = backend._lib.RSA_blinding_on(rsa_cdata, backend._ffi.NULL)
|
|
||||||
backend.openssl_assert(res == 1)
|
|
||||||
|
|
||||||
self._backend = backend
|
|
||||||
self._rsa_cdata = rsa_cdata
|
|
||||||
self._evp_pkey = evp_pkey
|
|
||||||
|
|
||||||
n = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.RSA_get0_key(
|
|
||||||
self._rsa_cdata,
|
|
||||||
n,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(n[0] != self._backend._ffi.NULL)
|
|
||||||
self._key_size = self._backend._lib.BN_num_bits(n[0])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def key_size(self) -> int:
|
|
||||||
return self._key_size
|
|
||||||
|
|
||||||
def decrypt(self, ciphertext: bytes, padding: AsymmetricPadding) -> bytes:
|
|
||||||
key_size_bytes = (self.key_size + 7) // 8
|
|
||||||
if key_size_bytes != len(ciphertext):
|
|
||||||
raise ValueError("Ciphertext length must be equal to key size.")
|
|
||||||
|
|
||||||
return _enc_dec_rsa(self._backend, self, ciphertext, padding)
|
|
||||||
|
|
||||||
def public_key(self) -> RSAPublicKey:
|
|
||||||
ctx = self._backend._lib.RSAPublicKey_dup(self._rsa_cdata)
|
|
||||||
self._backend.openssl_assert(ctx != self._backend._ffi.NULL)
|
|
||||||
ctx = self._backend._ffi.gc(ctx, self._backend._lib.RSA_free)
|
|
||||||
evp_pkey = self._backend._rsa_cdata_to_evp_pkey(ctx)
|
|
||||||
return _RSAPublicKey(self._backend, ctx, evp_pkey)
|
|
||||||
|
|
||||||
def private_numbers(self) -> RSAPrivateNumbers:
|
|
||||||
n = self._backend._ffi.new("BIGNUM **")
|
|
||||||
e = self._backend._ffi.new("BIGNUM **")
|
|
||||||
d = self._backend._ffi.new("BIGNUM **")
|
|
||||||
p = self._backend._ffi.new("BIGNUM **")
|
|
||||||
q = self._backend._ffi.new("BIGNUM **")
|
|
||||||
dmp1 = self._backend._ffi.new("BIGNUM **")
|
|
||||||
dmq1 = self._backend._ffi.new("BIGNUM **")
|
|
||||||
iqmp = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.RSA_get0_key(self._rsa_cdata, n, e, d)
|
|
||||||
self._backend.openssl_assert(n[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(e[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(d[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend._lib.RSA_get0_factors(self._rsa_cdata, p, q)
|
|
||||||
self._backend.openssl_assert(p[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(q[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend._lib.RSA_get0_crt_params(
|
|
||||||
self._rsa_cdata, dmp1, dmq1, iqmp
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(dmp1[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(dmq1[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(iqmp[0] != self._backend._ffi.NULL)
|
|
||||||
return RSAPrivateNumbers(
|
|
||||||
p=self._backend._bn_to_int(p[0]),
|
|
||||||
q=self._backend._bn_to_int(q[0]),
|
|
||||||
d=self._backend._bn_to_int(d[0]),
|
|
||||||
dmp1=self._backend._bn_to_int(dmp1[0]),
|
|
||||||
dmq1=self._backend._bn_to_int(dmq1[0]),
|
|
||||||
iqmp=self._backend._bn_to_int(iqmp[0]),
|
|
||||||
public_numbers=RSAPublicNumbers(
|
|
||||||
e=self._backend._bn_to_int(e[0]),
|
|
||||||
n=self._backend._bn_to_int(n[0]),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def private_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.PrivateFormat,
|
|
||||||
encryption_algorithm: serialization.KeySerializationEncryption,
|
|
||||||
) -> bytes:
|
|
||||||
return self._backend._private_key_bytes(
|
|
||||||
encoding,
|
|
||||||
format,
|
|
||||||
encryption_algorithm,
|
|
||||||
self,
|
|
||||||
self._evp_pkey,
|
|
||||||
self._rsa_cdata,
|
|
||||||
)
|
|
||||||
|
|
||||||
def sign(
|
|
||||||
self,
|
|
||||||
data: bytes,
|
|
||||||
padding: AsymmetricPadding,
|
|
||||||
algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
|
|
||||||
) -> bytes:
|
|
||||||
data, algorithm = _calculate_digest_and_algorithm(data, algorithm)
|
|
||||||
return _rsa_sig_sign(self._backend, padding, algorithm, self, data)
|
|
||||||
|
|
||||||
|
|
||||||
class _RSAPublicKey(RSAPublicKey):
|
|
||||||
_evp_pkey: object
|
|
||||||
_rsa_cdata: object
|
|
||||||
_key_size: int
|
|
||||||
|
|
||||||
def __init__(self, backend: "Backend", rsa_cdata, evp_pkey):
|
|
||||||
self._backend = backend
|
|
||||||
self._rsa_cdata = rsa_cdata
|
|
||||||
self._evp_pkey = evp_pkey
|
|
||||||
|
|
||||||
n = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.RSA_get0_key(
|
|
||||||
self._rsa_cdata,
|
|
||||||
n,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(n[0] != self._backend._ffi.NULL)
|
|
||||||
self._key_size = self._backend._lib.BN_num_bits(n[0])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def key_size(self) -> int:
|
|
||||||
return self._key_size
|
|
||||||
|
|
||||||
def encrypt(self, plaintext: bytes, padding: AsymmetricPadding) -> bytes:
|
|
||||||
return _enc_dec_rsa(self._backend, self, plaintext, padding)
|
|
||||||
|
|
||||||
def public_numbers(self) -> RSAPublicNumbers:
|
|
||||||
n = self._backend._ffi.new("BIGNUM **")
|
|
||||||
e = self._backend._ffi.new("BIGNUM **")
|
|
||||||
self._backend._lib.RSA_get0_key(
|
|
||||||
self._rsa_cdata, n, e, self._backend._ffi.NULL
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(n[0] != self._backend._ffi.NULL)
|
|
||||||
self._backend.openssl_assert(e[0] != self._backend._ffi.NULL)
|
|
||||||
return RSAPublicNumbers(
|
|
||||||
e=self._backend._bn_to_int(e[0]),
|
|
||||||
n=self._backend._bn_to_int(n[0]),
|
|
||||||
)
|
|
||||||
|
|
||||||
def public_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.PublicFormat,
|
|
||||||
) -> bytes:
|
|
||||||
return self._backend._public_key_bytes(
|
|
||||||
encoding, format, self, self._evp_pkey, self._rsa_cdata
|
|
||||||
)
|
|
||||||
|
|
||||||
def verify(
|
|
||||||
self,
|
|
||||||
signature: bytes,
|
|
||||||
data: bytes,
|
|
||||||
padding: AsymmetricPadding,
|
|
||||||
algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
|
|
||||||
) -> None:
|
|
||||||
data, algorithm = _calculate_digest_and_algorithm(data, algorithm)
|
|
||||||
_rsa_sig_verify(
|
|
||||||
self._backend, padding, algorithm, self, signature, data
|
|
||||||
)
|
|
||||||
|
|
||||||
def recover_data_from_signature(
|
|
||||||
self,
|
|
||||||
signature: bytes,
|
|
||||||
padding: AsymmetricPadding,
|
|
||||||
algorithm: typing.Optional[hashes.HashAlgorithm],
|
|
||||||
) -> bytes:
|
|
||||||
if isinstance(algorithm, asym_utils.Prehashed):
|
|
||||||
raise TypeError(
|
|
||||||
"Prehashed is only supported in the sign and verify methods. "
|
|
||||||
"It cannot be used with recover_data_from_signature."
|
|
||||||
)
|
|
||||||
return _rsa_sig_recover(
|
|
||||||
self._backend, padding, algorithm, self, signature
|
|
||||||
)
|
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.hazmat.primitives import hashes
|
|
||||||
from cryptography.hazmat.primitives.asymmetric.utils import Prehashed
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import Backend
|
|
||||||
|
|
||||||
|
|
||||||
def _evp_pkey_derive(backend: "Backend", evp_pkey, peer_public_key) -> bytes:
|
|
||||||
ctx = backend._lib.EVP_PKEY_CTX_new(evp_pkey, backend._ffi.NULL)
|
|
||||||
backend.openssl_assert(ctx != backend._ffi.NULL)
|
|
||||||
ctx = backend._ffi.gc(ctx, backend._lib.EVP_PKEY_CTX_free)
|
|
||||||
res = backend._lib.EVP_PKEY_derive_init(ctx)
|
|
||||||
backend.openssl_assert(res == 1)
|
|
||||||
res = backend._lib.EVP_PKEY_derive_set_peer(ctx, peer_public_key._evp_pkey)
|
|
||||||
backend.openssl_assert(res == 1)
|
|
||||||
keylen = backend._ffi.new("size_t *")
|
|
||||||
res = backend._lib.EVP_PKEY_derive(ctx, backend._ffi.NULL, keylen)
|
|
||||||
backend.openssl_assert(res == 1)
|
|
||||||
backend.openssl_assert(keylen[0] > 0)
|
|
||||||
buf = backend._ffi.new("unsigned char[]", keylen[0])
|
|
||||||
res = backend._lib.EVP_PKEY_derive(ctx, buf, keylen)
|
|
||||||
if res != 1:
|
|
||||||
errors_with_text = backend._consume_errors_with_text()
|
|
||||||
raise ValueError("Error computing shared key.", errors_with_text)
|
|
||||||
|
|
||||||
return backend._ffi.buffer(buf, keylen[0])[:]
|
|
||||||
|
|
||||||
|
|
||||||
def _calculate_digest_and_algorithm(
|
|
||||||
data: bytes,
|
|
||||||
algorithm: typing.Union[Prehashed, hashes.HashAlgorithm],
|
|
||||||
) -> typing.Tuple[bytes, hashes.HashAlgorithm]:
|
|
||||||
if not isinstance(algorithm, Prehashed):
|
|
||||||
hash_ctx = hashes.Hash(algorithm)
|
|
||||||
hash_ctx.update(data)
|
|
||||||
data = hash_ctx.finalize()
|
|
||||||
else:
|
|
||||||
algorithm = algorithm._algorithm
|
|
||||||
|
|
||||||
if len(data) != algorithm.digest_size:
|
|
||||||
raise ValueError(
|
|
||||||
"The provided data must be the same length as the hash "
|
|
||||||
"algorithm's digest size."
|
|
||||||
)
|
|
||||||
|
|
||||||
return (data, algorithm)
|
|
||||||
@@ -1,132 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.hazmat.backends.openssl.utils import _evp_pkey_derive
|
|
||||||
from cryptography.hazmat.primitives import serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric.x25519 import (
|
|
||||||
X25519PrivateKey,
|
|
||||||
X25519PublicKey,
|
|
||||||
)
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import Backend
|
|
||||||
|
|
||||||
|
|
||||||
_X25519_KEY_SIZE = 32
|
|
||||||
|
|
||||||
|
|
||||||
class _X25519PublicKey(X25519PublicKey):
|
|
||||||
def __init__(self, backend: "Backend", evp_pkey):
|
|
||||||
self._backend = backend
|
|
||||||
self._evp_pkey = evp_pkey
|
|
||||||
|
|
||||||
def public_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.PublicFormat,
|
|
||||||
) -> bytes:
|
|
||||||
if (
|
|
||||||
encoding is serialization.Encoding.Raw
|
|
||||||
or format is serialization.PublicFormat.Raw
|
|
||||||
):
|
|
||||||
if (
|
|
||||||
encoding is not serialization.Encoding.Raw
|
|
||||||
or format is not serialization.PublicFormat.Raw
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"When using Raw both encoding and format must be Raw"
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._raw_public_bytes()
|
|
||||||
|
|
||||||
return self._backend._public_key_bytes(
|
|
||||||
encoding, format, self, self._evp_pkey, None
|
|
||||||
)
|
|
||||||
|
|
||||||
def _raw_public_bytes(self) -> bytes:
|
|
||||||
ucharpp = self._backend._ffi.new("unsigned char **")
|
|
||||||
res = self._backend._lib.EVP_PKEY_get1_tls_encodedpoint(
|
|
||||||
self._evp_pkey, ucharpp
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 32)
|
|
||||||
self._backend.openssl_assert(ucharpp[0] != self._backend._ffi.NULL)
|
|
||||||
data = self._backend._ffi.gc(
|
|
||||||
ucharpp[0], self._backend._lib.OPENSSL_free
|
|
||||||
)
|
|
||||||
return self._backend._ffi.buffer(data, res)[:]
|
|
||||||
|
|
||||||
|
|
||||||
class _X25519PrivateKey(X25519PrivateKey):
|
|
||||||
def __init__(self, backend: "Backend", evp_pkey):
|
|
||||||
self._backend = backend
|
|
||||||
self._evp_pkey = evp_pkey
|
|
||||||
|
|
||||||
def public_key(self) -> X25519PublicKey:
|
|
||||||
bio = self._backend._create_mem_bio_gc()
|
|
||||||
res = self._backend._lib.i2d_PUBKEY_bio(bio, self._evp_pkey)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
evp_pkey = self._backend._lib.d2i_PUBKEY_bio(
|
|
||||||
bio, self._backend._ffi.NULL
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(evp_pkey != self._backend._ffi.NULL)
|
|
||||||
evp_pkey = self._backend._ffi.gc(
|
|
||||||
evp_pkey, self._backend._lib.EVP_PKEY_free
|
|
||||||
)
|
|
||||||
return _X25519PublicKey(self._backend, evp_pkey)
|
|
||||||
|
|
||||||
def exchange(self, peer_public_key: X25519PublicKey) -> bytes:
|
|
||||||
if not isinstance(peer_public_key, X25519PublicKey):
|
|
||||||
raise TypeError("peer_public_key must be X25519PublicKey.")
|
|
||||||
|
|
||||||
return _evp_pkey_derive(self._backend, self._evp_pkey, peer_public_key)
|
|
||||||
|
|
||||||
def private_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.PrivateFormat,
|
|
||||||
encryption_algorithm: serialization.KeySerializationEncryption,
|
|
||||||
) -> bytes:
|
|
||||||
if (
|
|
||||||
encoding is serialization.Encoding.Raw
|
|
||||||
or format is serialization.PublicFormat.Raw
|
|
||||||
):
|
|
||||||
if (
|
|
||||||
format is not serialization.PrivateFormat.Raw
|
|
||||||
or encoding is not serialization.Encoding.Raw
|
|
||||||
or not isinstance(
|
|
||||||
encryption_algorithm, serialization.NoEncryption
|
|
||||||
)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"When using Raw both encoding and format must be Raw "
|
|
||||||
"and encryption_algorithm must be NoEncryption()"
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._raw_private_bytes()
|
|
||||||
|
|
||||||
return self._backend._private_key_bytes(
|
|
||||||
encoding, format, encryption_algorithm, self, self._evp_pkey, None
|
|
||||||
)
|
|
||||||
|
|
||||||
def _raw_private_bytes(self) -> bytes:
|
|
||||||
# When we drop support for CRYPTOGRAPHY_OPENSSL_LESS_THAN_111 we can
|
|
||||||
# switch this to EVP_PKEY_new_raw_private_key
|
|
||||||
# The trick we use here is serializing to a PKCS8 key and just
|
|
||||||
# using the last 32 bytes, which is the key itself.
|
|
||||||
bio = self._backend._create_mem_bio_gc()
|
|
||||||
res = self._backend._lib.i2d_PKCS8PrivateKey_bio(
|
|
||||||
bio,
|
|
||||||
self._evp_pkey,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
0,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
self._backend._ffi.NULL,
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
pkcs8 = self._backend._read_mem_bio(bio)
|
|
||||||
self._backend.openssl_assert(len(pkcs8) == 48)
|
|
||||||
return pkcs8[-_X25519_KEY_SIZE:]
|
|
||||||
@@ -1,117 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.hazmat.backends.openssl.utils import _evp_pkey_derive
|
|
||||||
from cryptography.hazmat.primitives import serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric.x448 import (
|
|
||||||
X448PrivateKey,
|
|
||||||
X448PublicKey,
|
|
||||||
)
|
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import Backend
|
|
||||||
|
|
||||||
_X448_KEY_SIZE = 56
|
|
||||||
|
|
||||||
|
|
||||||
class _X448PublicKey(X448PublicKey):
|
|
||||||
def __init__(self, backend: "Backend", evp_pkey):
|
|
||||||
self._backend = backend
|
|
||||||
self._evp_pkey = evp_pkey
|
|
||||||
|
|
||||||
def public_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.PublicFormat,
|
|
||||||
) -> bytes:
|
|
||||||
if (
|
|
||||||
encoding is serialization.Encoding.Raw
|
|
||||||
or format is serialization.PublicFormat.Raw
|
|
||||||
):
|
|
||||||
if (
|
|
||||||
encoding is not serialization.Encoding.Raw
|
|
||||||
or format is not serialization.PublicFormat.Raw
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"When using Raw both encoding and format must be Raw"
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._raw_public_bytes()
|
|
||||||
|
|
||||||
return self._backend._public_key_bytes(
|
|
||||||
encoding, format, self, self._evp_pkey, None
|
|
||||||
)
|
|
||||||
|
|
||||||
def _raw_public_bytes(self) -> bytes:
|
|
||||||
buf = self._backend._ffi.new("unsigned char []", _X448_KEY_SIZE)
|
|
||||||
buflen = self._backend._ffi.new("size_t *", _X448_KEY_SIZE)
|
|
||||||
res = self._backend._lib.EVP_PKEY_get_raw_public_key(
|
|
||||||
self._evp_pkey, buf, buflen
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
self._backend.openssl_assert(buflen[0] == _X448_KEY_SIZE)
|
|
||||||
return self._backend._ffi.buffer(buf, _X448_KEY_SIZE)[:]
|
|
||||||
|
|
||||||
|
|
||||||
class _X448PrivateKey(X448PrivateKey):
|
|
||||||
def __init__(self, backend: "Backend", evp_pkey):
|
|
||||||
self._backend = backend
|
|
||||||
self._evp_pkey = evp_pkey
|
|
||||||
|
|
||||||
def public_key(self) -> X448PublicKey:
|
|
||||||
buf = self._backend._ffi.new("unsigned char []", _X448_KEY_SIZE)
|
|
||||||
buflen = self._backend._ffi.new("size_t *", _X448_KEY_SIZE)
|
|
||||||
res = self._backend._lib.EVP_PKEY_get_raw_public_key(
|
|
||||||
self._evp_pkey, buf, buflen
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
self._backend.openssl_assert(buflen[0] == _X448_KEY_SIZE)
|
|
||||||
public_bytes = self._backend._ffi.buffer(buf)[:]
|
|
||||||
return self._backend.x448_load_public_bytes(public_bytes)
|
|
||||||
|
|
||||||
def exchange(self, peer_public_key: X448PublicKey) -> bytes:
|
|
||||||
if not isinstance(peer_public_key, X448PublicKey):
|
|
||||||
raise TypeError("peer_public_key must be X448PublicKey.")
|
|
||||||
|
|
||||||
return _evp_pkey_derive(self._backend, self._evp_pkey, peer_public_key)
|
|
||||||
|
|
||||||
def private_bytes(
|
|
||||||
self,
|
|
||||||
encoding: serialization.Encoding,
|
|
||||||
format: serialization.PrivateFormat,
|
|
||||||
encryption_algorithm: serialization.KeySerializationEncryption,
|
|
||||||
) -> bytes:
|
|
||||||
if (
|
|
||||||
encoding is serialization.Encoding.Raw
|
|
||||||
or format is serialization.PublicFormat.Raw
|
|
||||||
):
|
|
||||||
if (
|
|
||||||
format is not serialization.PrivateFormat.Raw
|
|
||||||
or encoding is not serialization.Encoding.Raw
|
|
||||||
or not isinstance(
|
|
||||||
encryption_algorithm, serialization.NoEncryption
|
|
||||||
)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"When using Raw both encoding and format must be Raw "
|
|
||||||
"and encryption_algorithm must be NoEncryption()"
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._raw_private_bytes()
|
|
||||||
|
|
||||||
return self._backend._private_key_bytes(
|
|
||||||
encoding, format, encryption_algorithm, self, self._evp_pkey, None
|
|
||||||
)
|
|
||||||
|
|
||||||
def _raw_private_bytes(self) -> bytes:
|
|
||||||
buf = self._backend._ffi.new("unsigned char []", _X448_KEY_SIZE)
|
|
||||||
buflen = self._backend._ffi.new("size_t *", _X448_KEY_SIZE)
|
|
||||||
res = self._backend._lib.EVP_PKEY_get_raw_private_key(
|
|
||||||
self._evp_pkey, buf, buflen
|
|
||||||
)
|
|
||||||
self._backend.openssl_assert(res == 1)
|
|
||||||
self._backend.openssl_assert(buflen[0] == _X448_KEY_SIZE)
|
|
||||||
return self._backend._ffi.buffer(buf, _X448_KEY_SIZE)[:]
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
from cryptography import utils, x509
|
|
||||||
|
|
||||||
|
|
||||||
# This exists for pyOpenSSL compatibility and SHOULD NOT BE USED
|
|
||||||
# WE WILL REMOVE THIS VERY SOON.
|
|
||||||
def _Certificate(backend, x509) -> x509.Certificate: # noqa: N802
|
|
||||||
warnings.warn(
|
|
||||||
"This version of cryptography contains a temporary pyOpenSSL "
|
|
||||||
"fallback path. Upgrade pyOpenSSL now.",
|
|
||||||
utils.DeprecatedIn35,
|
|
||||||
)
|
|
||||||
return backend._ossl2cert(x509)
|
|
||||||
|
|
||||||
|
|
||||||
# This exists for pyOpenSSL compatibility and SHOULD NOT BE USED
|
|
||||||
# WE WILL REMOVE THIS VERY SOON.
|
|
||||||
def _CertificateSigningRequest( # noqa: N802
|
|
||||||
backend, x509_req
|
|
||||||
) -> x509.CertificateSigningRequest:
|
|
||||||
warnings.warn(
|
|
||||||
"This version of cryptography contains a temporary pyOpenSSL "
|
|
||||||
"fallback path. Upgrade pyOpenSSL now.",
|
|
||||||
utils.DeprecatedIn35,
|
|
||||||
)
|
|
||||||
return backend._ossl2csr(x509_req)
|
|
||||||
|
|
||||||
|
|
||||||
# This exists for pyOpenSSL compatibility and SHOULD NOT BE USED
|
|
||||||
# WE WILL REMOVE THIS VERY SOON.
|
|
||||||
def _CertificateRevocationList( # noqa: N802
|
|
||||||
backend, x509_crl
|
|
||||||
) -> x509.CertificateRevocationList:
|
|
||||||
warnings.warn(
|
|
||||||
"This version of cryptography contains a temporary pyOpenSSL "
|
|
||||||
"fallback path. Upgrade pyOpenSSL now.",
|
|
||||||
utils.DeprecatedIn35,
|
|
||||||
)
|
|
||||||
return backend._ossl2crl(x509_crl)
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
@@ -1,367 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_ec2m() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"EC_POINT_get_affine_coordinates_GF2m",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_ssl3_method() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"SSLv3_method",
|
|
||||||
"SSLv3_client_method",
|
|
||||||
"SSLv3_server_method",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_110_verification_params() -> typing.List[str]:
|
|
||||||
return ["X509_CHECK_FLAG_NEVER_CHECK_SUBJECT"]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_set_cert_cb() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"SSL_CTX_set_cert_cb",
|
|
||||||
"SSL_set_cert_cb",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_ssl_st() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"SSL_ST_BEFORE",
|
|
||||||
"SSL_ST_OK",
|
|
||||||
"SSL_ST_INIT",
|
|
||||||
"SSL_ST_RENEGOTIATE",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_tls_st() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"TLS_ST_BEFORE",
|
|
||||||
"TLS_ST_OK",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_scrypt() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"EVP_PBE_scrypt",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_evp_pkey_dhx() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"EVP_PKEY_DHX",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_mem_functions() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"Cryptography_CRYPTO_set_mem_functions",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_x509_store_ctx_get_issuer() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"X509_STORE_get_get_issuer",
|
|
||||||
"X509_STORE_set_get_issuer",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_ed448() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"EVP_PKEY_ED448",
|
|
||||||
"NID_ED448",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_ed25519() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"NID_ED25519",
|
|
||||||
"EVP_PKEY_ED25519",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_poly1305() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"NID_poly1305",
|
|
||||||
"EVP_PKEY_POLY1305",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_oneshot_evp_digest_sign_verify() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"EVP_DigestSign",
|
|
||||||
"EVP_DigestVerify",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_evp_digestfinal_xof() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"EVP_DigestFinalXOF",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_evp_pkey_get_set_tls_encodedpoint() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"EVP_PKEY_get1_tls_encodedpoint",
|
|
||||||
"EVP_PKEY_set1_tls_encodedpoint",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_fips() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"FIPS_mode_set",
|
|
||||||
"FIPS_mode",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_psk() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"SSL_CTX_use_psk_identity_hint",
|
|
||||||
"SSL_CTX_set_psk_server_callback",
|
|
||||||
"SSL_CTX_set_psk_client_callback",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_psk_tlsv13() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"SSL_CTX_set_psk_find_session_callback",
|
|
||||||
"SSL_CTX_set_psk_use_session_callback",
|
|
||||||
"Cryptography_SSL_SESSION_new",
|
|
||||||
"SSL_CIPHER_find",
|
|
||||||
"SSL_SESSION_set1_master_key",
|
|
||||||
"SSL_SESSION_set_cipher",
|
|
||||||
"SSL_SESSION_set_protocol_version",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_custom_ext() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"SSL_CTX_add_client_custom_ext",
|
|
||||||
"SSL_CTX_add_server_custom_ext",
|
|
||||||
"SSL_extension_supported",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_openssl_cleanup() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"OPENSSL_cleanup",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_tlsv13() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"TLS1_3_VERSION",
|
|
||||||
"SSL_OP_NO_TLSv1_3",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_tlsv13_functions() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"SSL_VERIFY_POST_HANDSHAKE",
|
|
||||||
"SSL_CTX_set_ciphersuites",
|
|
||||||
"SSL_verify_client_post_handshake",
|
|
||||||
"SSL_CTX_set_post_handshake_auth",
|
|
||||||
"SSL_set_post_handshake_auth",
|
|
||||||
"SSL_SESSION_get_max_early_data",
|
|
||||||
"SSL_write_early_data",
|
|
||||||
"SSL_read_early_data",
|
|
||||||
"SSL_CTX_set_max_early_data",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_keylog() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"SSL_CTX_set_keylog_callback",
|
|
||||||
"SSL_CTX_get_keylog_callback",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_raw_key() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"EVP_PKEY_new_raw_private_key",
|
|
||||||
"EVP_PKEY_new_raw_public_key",
|
|
||||||
"EVP_PKEY_get_raw_private_key",
|
|
||||||
"EVP_PKEY_get_raw_public_key",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_engine() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"ENGINE_by_id",
|
|
||||||
"ENGINE_init",
|
|
||||||
"ENGINE_finish",
|
|
||||||
"ENGINE_get_default_RAND",
|
|
||||||
"ENGINE_set_default_RAND",
|
|
||||||
"ENGINE_unregister_RAND",
|
|
||||||
"ENGINE_ctrl_cmd",
|
|
||||||
"ENGINE_free",
|
|
||||||
"ENGINE_get_name",
|
|
||||||
"Cryptography_add_osrandom_engine",
|
|
||||||
"ENGINE_ctrl_cmd_string",
|
|
||||||
"ENGINE_load_builtin_engines",
|
|
||||||
"ENGINE_load_private_key",
|
|
||||||
"ENGINE_load_public_key",
|
|
||||||
"SSL_CTX_set_client_cert_engine",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_verified_chain() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"SSL_get0_verified_chain",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_srtp() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"SSL_CTX_set_tlsext_use_srtp",
|
|
||||||
"SSL_set_tlsext_use_srtp",
|
|
||||||
"SSL_get_selected_srtp_profile",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_get_proto_version() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"SSL_CTX_get_min_proto_version",
|
|
||||||
"SSL_CTX_get_max_proto_version",
|
|
||||||
"SSL_get_min_proto_version",
|
|
||||||
"SSL_get_max_proto_version",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_providers() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"OSSL_PROVIDER_load",
|
|
||||||
"OSSL_PROVIDER_unload",
|
|
||||||
"ERR_LIB_PROV",
|
|
||||||
"PROV_R_WRONG_FINAL_BLOCK_LENGTH",
|
|
||||||
"PROV_R_BAD_DECRYPT",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_op_no_renegotiation() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"SSL_OP_NO_RENEGOTIATION",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_dtls_get_data_mtu() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"DTLS_get_data_mtu",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_300_fips() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"EVP_default_properties_is_fips_enabled",
|
|
||||||
"EVP_default_properties_enable_fips",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_ssl_cookie() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"SSL_OP_COOKIE_EXCHANGE",
|
|
||||||
"DTLSv1_listen",
|
|
||||||
"SSL_CTX_set_cookie_generate_cb",
|
|
||||||
"SSL_CTX_set_cookie_verify_cb",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_pkcs7_funcs() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"SMIME_write_PKCS7",
|
|
||||||
"PEM_write_bio_PKCS7_stream",
|
|
||||||
"PKCS7_sign_add_signer",
|
|
||||||
"PKCS7_final",
|
|
||||||
"PKCS7_verify",
|
|
||||||
"SMIME_read_PKCS7",
|
|
||||||
"PKCS7_get0_signers",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_bn_flags() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"BN_FLG_CONSTTIME",
|
|
||||||
"BN_set_flags",
|
|
||||||
"BN_prime_checks_for_size",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_evp_pkey_dh() -> typing.List[str]:
|
|
||||||
return [
|
|
||||||
"EVP_PKEY_set1_DH",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_300_evp_cipher() -> typing.List[str]:
|
|
||||||
return ["EVP_CIPHER_fetch", "EVP_CIPHER_free"]
|
|
||||||
|
|
||||||
|
|
||||||
def cryptography_has_unexpected_eof_while_reading() -> typing.List[str]:
|
|
||||||
return ["SSL_R_UNEXPECTED_EOF_WHILE_READING"]
|
|
||||||
|
|
||||||
|
|
||||||
# This is a mapping of
|
|
||||||
# {condition: function-returning-names-dependent-on-that-condition} so we can
|
|
||||||
# loop over them and delete unsupported names at runtime. It will be removed
|
|
||||||
# when cffi supports #if in cdef. We use functions instead of just a dict of
|
|
||||||
# lists so we can use coverage to measure which are used.
|
|
||||||
CONDITIONAL_NAMES = {
|
|
||||||
"Cryptography_HAS_EC2M": cryptography_has_ec2m,
|
|
||||||
"Cryptography_HAS_SSL3_METHOD": cryptography_has_ssl3_method,
|
|
||||||
"Cryptography_HAS_110_VERIFICATION_PARAMS": (
|
|
||||||
cryptography_has_110_verification_params
|
|
||||||
),
|
|
||||||
"Cryptography_HAS_SET_CERT_CB": cryptography_has_set_cert_cb,
|
|
||||||
"Cryptography_HAS_SSL_ST": cryptography_has_ssl_st,
|
|
||||||
"Cryptography_HAS_TLS_ST": cryptography_has_tls_st,
|
|
||||||
"Cryptography_HAS_SCRYPT": cryptography_has_scrypt,
|
|
||||||
"Cryptography_HAS_EVP_PKEY_DHX": cryptography_has_evp_pkey_dhx,
|
|
||||||
"Cryptography_HAS_MEM_FUNCTIONS": cryptography_has_mem_functions,
|
|
||||||
"Cryptography_HAS_X509_STORE_CTX_GET_ISSUER": (
|
|
||||||
cryptography_has_x509_store_ctx_get_issuer
|
|
||||||
),
|
|
||||||
"Cryptography_HAS_ED448": cryptography_has_ed448,
|
|
||||||
"Cryptography_HAS_ED25519": cryptography_has_ed25519,
|
|
||||||
"Cryptography_HAS_POLY1305": cryptography_has_poly1305,
|
|
||||||
"Cryptography_HAS_ONESHOT_EVP_DIGEST_SIGN_VERIFY": (
|
|
||||||
cryptography_has_oneshot_evp_digest_sign_verify
|
|
||||||
),
|
|
||||||
"Cryptography_HAS_EVP_PKEY_get_set_tls_encodedpoint": (
|
|
||||||
cryptography_has_evp_pkey_get_set_tls_encodedpoint
|
|
||||||
),
|
|
||||||
"Cryptography_HAS_FIPS": cryptography_has_fips,
|
|
||||||
"Cryptography_HAS_PSK": cryptography_has_psk,
|
|
||||||
"Cryptography_HAS_PSK_TLSv1_3": cryptography_has_psk_tlsv13,
|
|
||||||
"Cryptography_HAS_CUSTOM_EXT": cryptography_has_custom_ext,
|
|
||||||
"Cryptography_HAS_OPENSSL_CLEANUP": cryptography_has_openssl_cleanup,
|
|
||||||
"Cryptography_HAS_TLSv1_3": cryptography_has_tlsv13,
|
|
||||||
"Cryptography_HAS_TLSv1_3_FUNCTIONS": cryptography_has_tlsv13_functions,
|
|
||||||
"Cryptography_HAS_KEYLOG": cryptography_has_keylog,
|
|
||||||
"Cryptography_HAS_RAW_KEY": cryptography_has_raw_key,
|
|
||||||
"Cryptography_HAS_EVP_DIGESTFINAL_XOF": (
|
|
||||||
cryptography_has_evp_digestfinal_xof
|
|
||||||
),
|
|
||||||
"Cryptography_HAS_ENGINE": cryptography_has_engine,
|
|
||||||
"Cryptography_HAS_VERIFIED_CHAIN": cryptography_has_verified_chain,
|
|
||||||
"Cryptography_HAS_SRTP": cryptography_has_srtp,
|
|
||||||
"Cryptography_HAS_GET_PROTO_VERSION": cryptography_has_get_proto_version,
|
|
||||||
"Cryptography_HAS_PROVIDERS": cryptography_has_providers,
|
|
||||||
"Cryptography_HAS_OP_NO_RENEGOTIATION": (
|
|
||||||
cryptography_has_op_no_renegotiation
|
|
||||||
),
|
|
||||||
"Cryptography_HAS_DTLS_GET_DATA_MTU": cryptography_has_dtls_get_data_mtu,
|
|
||||||
"Cryptography_HAS_300_FIPS": cryptography_has_300_fips,
|
|
||||||
"Cryptography_HAS_SSL_COOKIE": cryptography_has_ssl_cookie,
|
|
||||||
"Cryptography_HAS_PKCS7_FUNCS": cryptography_has_pkcs7_funcs,
|
|
||||||
"Cryptography_HAS_BN_FLAGS": cryptography_has_bn_flags,
|
|
||||||
"Cryptography_HAS_EVP_PKEY_DH": cryptography_has_evp_pkey_dh,
|
|
||||||
"Cryptography_HAS_300_EVP_CIPHER": cryptography_has_300_evp_cipher,
|
|
||||||
"Cryptography_HAS_UNEXPECTED_EOF_WHILE_READING": (
|
|
||||||
cryptography_has_unexpected_eof_while_reading
|
|
||||||
),
|
|
||||||
}
|
|
||||||
@@ -1,230 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
import threading
|
|
||||||
import types
|
|
||||||
import typing
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import cryptography
|
|
||||||
from cryptography import utils
|
|
||||||
from cryptography.exceptions import InternalError
|
|
||||||
from cryptography.hazmat.bindings._openssl import ffi, lib
|
|
||||||
from cryptography.hazmat.bindings.openssl._conditional import CONDITIONAL_NAMES
|
|
||||||
|
|
||||||
_OpenSSLErrorWithText = typing.NamedTuple(
|
|
||||||
"_OpenSSLErrorWithText",
|
|
||||||
[("code", int), ("lib", int), ("reason", int), ("reason_text", bytes)],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _OpenSSLError:
|
|
||||||
def __init__(self, code: int, lib: int, reason: int):
|
|
||||||
self._code = code
|
|
||||||
self._lib = lib
|
|
||||||
self._reason = reason
|
|
||||||
|
|
||||||
def _lib_reason_match(self, lib: int, reason: int) -> bool:
|
|
||||||
return lib == self.lib and reason == self.reason
|
|
||||||
|
|
||||||
@property
|
|
||||||
def code(self) -> int:
|
|
||||||
return self._code
|
|
||||||
|
|
||||||
@property
|
|
||||||
def lib(self) -> int:
|
|
||||||
return self._lib
|
|
||||||
|
|
||||||
@property
|
|
||||||
def reason(self) -> int:
|
|
||||||
return self._reason
|
|
||||||
|
|
||||||
|
|
||||||
def _consume_errors(lib) -> typing.List[_OpenSSLError]:
|
|
||||||
errors = []
|
|
||||||
while True:
|
|
||||||
code: int = lib.ERR_get_error()
|
|
||||||
if code == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
err_lib: int = lib.ERR_GET_LIB(code)
|
|
||||||
err_reason: int = lib.ERR_GET_REASON(code)
|
|
||||||
|
|
||||||
errors.append(_OpenSSLError(code, err_lib, err_reason))
|
|
||||||
|
|
||||||
return errors
|
|
||||||
|
|
||||||
|
|
||||||
def _errors_with_text(
|
|
||||||
errors: typing.List[_OpenSSLError],
|
|
||||||
) -> typing.List[_OpenSSLErrorWithText]:
|
|
||||||
errors_with_text = []
|
|
||||||
for err in errors:
|
|
||||||
buf = ffi.new("char[]", 256)
|
|
||||||
lib.ERR_error_string_n(err.code, buf, len(buf))
|
|
||||||
err_text_reason: bytes = ffi.string(buf)
|
|
||||||
|
|
||||||
errors_with_text.append(
|
|
||||||
_OpenSSLErrorWithText(
|
|
||||||
err.code, err.lib, err.reason, err_text_reason
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return errors_with_text
|
|
||||||
|
|
||||||
|
|
||||||
def _consume_errors_with_text(lib):
|
|
||||||
return _errors_with_text(_consume_errors(lib))
|
|
||||||
|
|
||||||
|
|
||||||
def _openssl_assert(
|
|
||||||
lib, ok: bool, errors: typing.Optional[typing.List[_OpenSSLError]] = None
|
|
||||||
) -> None:
|
|
||||||
if not ok:
|
|
||||||
if errors is None:
|
|
||||||
errors = _consume_errors(lib)
|
|
||||||
errors_with_text = _errors_with_text(errors)
|
|
||||||
|
|
||||||
raise InternalError(
|
|
||||||
"Unknown OpenSSL error. This error is commonly encountered when "
|
|
||||||
"another library is not cleaning up the OpenSSL error stack. If "
|
|
||||||
"you are using cryptography with another library that uses "
|
|
||||||
"OpenSSL try disabling it before reporting a bug. Otherwise "
|
|
||||||
"please file an issue at https://github.com/pyca/cryptography/"
|
|
||||||
"issues with information on how to reproduce "
|
|
||||||
"this. ({0!r})".format(errors_with_text),
|
|
||||||
errors_with_text,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_conditional_library(lib, conditional_names):
|
|
||||||
conditional_lib = types.ModuleType("lib")
|
|
||||||
conditional_lib._original_lib = lib # type: ignore[attr-defined]
|
|
||||||
excluded_names = set()
|
|
||||||
for condition, names_cb in conditional_names.items():
|
|
||||||
if not getattr(lib, condition):
|
|
||||||
excluded_names.update(names_cb())
|
|
||||||
|
|
||||||
for attr in dir(lib):
|
|
||||||
if attr not in excluded_names:
|
|
||||||
setattr(conditional_lib, attr, getattr(lib, attr))
|
|
||||||
|
|
||||||
return conditional_lib
|
|
||||||
|
|
||||||
|
|
||||||
class Binding:
|
|
||||||
"""
|
|
||||||
OpenSSL API wrapper.
|
|
||||||
"""
|
|
||||||
|
|
||||||
lib: typing.ClassVar = None
|
|
||||||
ffi = ffi
|
|
||||||
_lib_loaded = False
|
|
||||||
_init_lock = threading.Lock()
|
|
||||||
_legacy_provider: typing.Any = None
|
|
||||||
_default_provider: typing.Any = None
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._ensure_ffi_initialized()
|
|
||||||
|
|
||||||
def _enable_fips(self) -> None:
|
|
||||||
# This function enables FIPS mode for OpenSSL 3.0.0 on installs that
|
|
||||||
# have the FIPS provider installed properly.
|
|
||||||
_openssl_assert(self.lib, self.lib.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER)
|
|
||||||
self._base_provider = self.lib.OSSL_PROVIDER_load(
|
|
||||||
self.ffi.NULL, b"base"
|
|
||||||
)
|
|
||||||
_openssl_assert(self.lib, self._base_provider != self.ffi.NULL)
|
|
||||||
self.lib._fips_provider = self.lib.OSSL_PROVIDER_load(
|
|
||||||
self.ffi.NULL, b"fips"
|
|
||||||
)
|
|
||||||
_openssl_assert(self.lib, self.lib._fips_provider != self.ffi.NULL)
|
|
||||||
|
|
||||||
res = self.lib.EVP_default_properties_enable_fips(self.ffi.NULL, 1)
|
|
||||||
_openssl_assert(self.lib, res == 1)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _register_osrandom_engine(cls):
|
|
||||||
# Clear any errors extant in the queue before we start. In many
|
|
||||||
# scenarios other things may be interacting with OpenSSL in the same
|
|
||||||
# process space and it has proven untenable to assume that they will
|
|
||||||
# reliably clear the error queue. Once we clear it here we will
|
|
||||||
# error on any subsequent unexpected item in the stack.
|
|
||||||
cls.lib.ERR_clear_error()
|
|
||||||
if cls.lib.CRYPTOGRAPHY_NEEDS_OSRANDOM_ENGINE:
|
|
||||||
result = cls.lib.Cryptography_add_osrandom_engine()
|
|
||||||
_openssl_assert(cls.lib, result in (1, 2))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _ensure_ffi_initialized(cls):
|
|
||||||
with cls._init_lock:
|
|
||||||
if not cls._lib_loaded:
|
|
||||||
cls.lib = build_conditional_library(lib, CONDITIONAL_NAMES)
|
|
||||||
cls._lib_loaded = True
|
|
||||||
cls._register_osrandom_engine()
|
|
||||||
# As of OpenSSL 3.0.0 we must register a legacy cipher provider
|
|
||||||
# to get RC2 (needed for junk asymmetric private key
|
|
||||||
# serialization), RC4, Blowfish, IDEA, SEED, etc. These things
|
|
||||||
# are ugly legacy, but we aren't going to get rid of them
|
|
||||||
# any time soon.
|
|
||||||
if cls.lib.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER:
|
|
||||||
cls._legacy_provider = cls.lib.OSSL_PROVIDER_load(
|
|
||||||
cls.ffi.NULL, b"legacy"
|
|
||||||
)
|
|
||||||
_openssl_assert(
|
|
||||||
cls.lib, cls._legacy_provider != cls.ffi.NULL
|
|
||||||
)
|
|
||||||
cls._default_provider = cls.lib.OSSL_PROVIDER_load(
|
|
||||||
cls.ffi.NULL, b"default"
|
|
||||||
)
|
|
||||||
_openssl_assert(
|
|
||||||
cls.lib, cls._default_provider != cls.ffi.NULL
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def init_static_locks(cls):
|
|
||||||
cls._ensure_ffi_initialized()
|
|
||||||
|
|
||||||
|
|
||||||
def _verify_openssl_version(lib):
|
|
||||||
if (
|
|
||||||
lib.CRYPTOGRAPHY_OPENSSL_LESS_THAN_111
|
|
||||||
and not lib.CRYPTOGRAPHY_IS_LIBRESSL
|
|
||||||
and not lib.CRYPTOGRAPHY_IS_BORINGSSL
|
|
||||||
):
|
|
||||||
warnings.warn(
|
|
||||||
"OpenSSL version 1.1.0 is no longer supported by the OpenSSL "
|
|
||||||
"project, please upgrade. The next release of cryptography will "
|
|
||||||
"be the last to support compiling with OpenSSL 1.1.0.",
|
|
||||||
utils.DeprecatedIn37,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _verify_package_version(version):
|
|
||||||
# Occasionally we run into situations where the version of the Python
|
|
||||||
# package does not match the version of the shared object that is loaded.
|
|
||||||
# This may occur in environments where multiple versions of cryptography
|
|
||||||
# are installed and available in the python path. To avoid errors cropping
|
|
||||||
# up later this code checks that the currently imported package and the
|
|
||||||
# shared object that were loaded have the same version and raise an
|
|
||||||
# ImportError if they do not
|
|
||||||
so_package_version = ffi.string(lib.CRYPTOGRAPHY_PACKAGE_VERSION)
|
|
||||||
if version.encode("ascii") != so_package_version:
|
|
||||||
raise ImportError(
|
|
||||||
"The version of cryptography does not match the loaded "
|
|
||||||
"shared object. This can happen if you have multiple copies of "
|
|
||||||
"cryptography installed in your Python path. Please try creating "
|
|
||||||
"a new virtual environment to resolve this issue. "
|
|
||||||
"Loaded python version: {}, shared object version: {}".format(
|
|
||||||
version, so_package_version
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_verify_package_version(cryptography.__version__)
|
|
||||||
|
|
||||||
Binding.init_static_locks()
|
|
||||||
|
|
||||||
_verify_openssl_version(Binding.lib)
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import abc
|
|
||||||
|
|
||||||
|
|
||||||
# This exists to break an import cycle. It is normally accessible from the
|
|
||||||
# asymmetric padding module.
|
|
||||||
|
|
||||||
|
|
||||||
class AsymmetricPadding(metaclass=abc.ABCMeta):
|
|
||||||
@abc.abstractproperty
|
|
||||||
def name(self) -> str:
|
|
||||||
"""
|
|
||||||
A string naming this padding (e.g. "PSS", "PKCS1").
|
|
||||||
"""
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import abc
|
|
||||||
import typing
|
|
||||||
|
|
||||||
|
|
||||||
# This exists to break an import cycle. It is normally accessible from the
|
|
||||||
# ciphers module.
|
|
||||||
|
|
||||||
|
|
||||||
class CipherAlgorithm(metaclass=abc.ABCMeta):
|
|
||||||
@abc.abstractproperty
|
|
||||||
def name(self) -> str:
|
|
||||||
"""
|
|
||||||
A string naming this mode (e.g. "AES", "Camellia").
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractproperty
|
|
||||||
def key_sizes(self) -> typing.FrozenSet[int]:
|
|
||||||
"""
|
|
||||||
Valid key sizes for this algorithm in bits
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractproperty
|
|
||||||
def key_size(self) -> int:
|
|
||||||
"""
|
|
||||||
The size of the key being used as an integer in bits (e.g. 128, 256).
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class BlockCipherAlgorithm(metaclass=abc.ABCMeta):
|
|
||||||
key: bytes
|
|
||||||
|
|
||||||
@abc.abstractproperty
|
|
||||||
def block_size(self) -> int:
|
|
||||||
"""
|
|
||||||
The size of a block as an integer in bits (e.g. 64, 128).
|
|
||||||
"""
|
|
||||||
@@ -1,55 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import abc
|
|
||||||
|
|
||||||
from cryptography import utils
|
|
||||||
|
|
||||||
# This exists to break an import cycle. These classes are normally accessible
|
|
||||||
# from the serialization module.
|
|
||||||
|
|
||||||
|
|
||||||
class Encoding(utils.Enum):
|
|
||||||
PEM = "PEM"
|
|
||||||
DER = "DER"
|
|
||||||
OpenSSH = "OpenSSH"
|
|
||||||
Raw = "Raw"
|
|
||||||
X962 = "ANSI X9.62"
|
|
||||||
SMIME = "S/MIME"
|
|
||||||
|
|
||||||
|
|
||||||
class PrivateFormat(utils.Enum):
|
|
||||||
PKCS8 = "PKCS8"
|
|
||||||
TraditionalOpenSSL = "TraditionalOpenSSL"
|
|
||||||
Raw = "Raw"
|
|
||||||
OpenSSH = "OpenSSH"
|
|
||||||
|
|
||||||
|
|
||||||
class PublicFormat(utils.Enum):
|
|
||||||
SubjectPublicKeyInfo = "X.509 subjectPublicKeyInfo with PKCS#1"
|
|
||||||
PKCS1 = "Raw PKCS#1"
|
|
||||||
OpenSSH = "OpenSSH"
|
|
||||||
Raw = "Raw"
|
|
||||||
CompressedPoint = "X9.62 Compressed Point"
|
|
||||||
UncompressedPoint = "X9.62 Uncompressed Point"
|
|
||||||
|
|
||||||
|
|
||||||
class ParameterFormat(utils.Enum):
|
|
||||||
PKCS3 = "PKCS3"
|
|
||||||
|
|
||||||
|
|
||||||
class KeySerializationEncryption(metaclass=abc.ABCMeta):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class BestAvailableEncryption(KeySerializationEncryption):
|
|
||||||
def __init__(self, password: bytes):
|
|
||||||
if not isinstance(password, bytes) or len(password) == 0:
|
|
||||||
raise ValueError("Password must be 1 or more bytes.")
|
|
||||||
|
|
||||||
self.password = password
|
|
||||||
|
|
||||||
|
|
||||||
class NoEncryption(KeySerializationEncryption):
|
|
||||||
pass
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
@@ -1,250 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
import abc
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.hazmat.primitives import _serialization
|
|
||||||
|
|
||||||
|
|
||||||
_MIN_MODULUS_SIZE = 512
|
|
||||||
|
|
||||||
|
|
||||||
def generate_parameters(
|
|
||||||
generator: int, key_size: int, backend: typing.Any = None
|
|
||||||
) -> "DHParameters":
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import backend as ossl
|
|
||||||
|
|
||||||
return ossl.generate_dh_parameters(generator, key_size)
|
|
||||||
|
|
||||||
|
|
||||||
class DHParameterNumbers:
|
|
||||||
def __init__(self, p: int, g: int, q: typing.Optional[int] = None) -> None:
|
|
||||||
if not isinstance(p, int) or not isinstance(g, int):
|
|
||||||
raise TypeError("p and g must be integers")
|
|
||||||
if q is not None and not isinstance(q, int):
|
|
||||||
raise TypeError("q must be integer or None")
|
|
||||||
|
|
||||||
if g < 2:
|
|
||||||
raise ValueError("DH generator must be 2 or greater")
|
|
||||||
|
|
||||||
if p.bit_length() < _MIN_MODULUS_SIZE:
|
|
||||||
raise ValueError(
|
|
||||||
"p (modulus) must be at least {}-bit".format(_MIN_MODULUS_SIZE)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._p = p
|
|
||||||
self._g = g
|
|
||||||
self._q = q
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, DHParameterNumbers):
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
return (
|
|
||||||
self._p == other._p and self._g == other._g and self._q == other._q
|
|
||||||
)
|
|
||||||
|
|
||||||
def parameters(self, backend: typing.Any = None) -> "DHParameters":
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import (
|
|
||||||
backend as ossl,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ossl.load_dh_parameter_numbers(self)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def p(self) -> int:
|
|
||||||
return self._p
|
|
||||||
|
|
||||||
@property
|
|
||||||
def g(self) -> int:
|
|
||||||
return self._g
|
|
||||||
|
|
||||||
@property
|
|
||||||
def q(self) -> typing.Optional[int]:
|
|
||||||
return self._q
|
|
||||||
|
|
||||||
|
|
||||||
class DHPublicNumbers:
|
|
||||||
def __init__(self, y: int, parameter_numbers: DHParameterNumbers) -> None:
|
|
||||||
if not isinstance(y, int):
|
|
||||||
raise TypeError("y must be an integer.")
|
|
||||||
|
|
||||||
if not isinstance(parameter_numbers, DHParameterNumbers):
|
|
||||||
raise TypeError(
|
|
||||||
"parameters must be an instance of DHParameterNumbers."
|
|
||||||
)
|
|
||||||
|
|
||||||
self._y = y
|
|
||||||
self._parameter_numbers = parameter_numbers
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, DHPublicNumbers):
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
return (
|
|
||||||
self._y == other._y
|
|
||||||
and self._parameter_numbers == other._parameter_numbers
|
|
||||||
)
|
|
||||||
|
|
||||||
def public_key(self, backend: typing.Any = None) -> "DHPublicKey":
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import (
|
|
||||||
backend as ossl,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ossl.load_dh_public_numbers(self)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def y(self) -> int:
|
|
||||||
return self._y
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameter_numbers(self) -> DHParameterNumbers:
|
|
||||||
return self._parameter_numbers
|
|
||||||
|
|
||||||
|
|
||||||
class DHPrivateNumbers:
|
|
||||||
def __init__(self, x: int, public_numbers: DHPublicNumbers) -> None:
|
|
||||||
if not isinstance(x, int):
|
|
||||||
raise TypeError("x must be an integer.")
|
|
||||||
|
|
||||||
if not isinstance(public_numbers, DHPublicNumbers):
|
|
||||||
raise TypeError(
|
|
||||||
"public_numbers must be an instance of " "DHPublicNumbers."
|
|
||||||
)
|
|
||||||
|
|
||||||
self._x = x
|
|
||||||
self._public_numbers = public_numbers
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, DHPrivateNumbers):
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
return (
|
|
||||||
self._x == other._x
|
|
||||||
and self._public_numbers == other._public_numbers
|
|
||||||
)
|
|
||||||
|
|
||||||
def private_key(self, backend: typing.Any = None) -> "DHPrivateKey":
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import (
|
|
||||||
backend as ossl,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ossl.load_dh_private_numbers(self)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def public_numbers(self) -> DHPublicNumbers:
|
|
||||||
return self._public_numbers
|
|
||||||
|
|
||||||
@property
|
|
||||||
def x(self) -> int:
|
|
||||||
return self._x
|
|
||||||
|
|
||||||
|
|
||||||
class DHParameters(metaclass=abc.ABCMeta):
|
|
||||||
@abc.abstractmethod
|
|
||||||
def generate_private_key(self) -> "DHPrivateKey":
|
|
||||||
"""
|
|
||||||
Generates and returns a DHPrivateKey.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def parameter_bytes(
|
|
||||||
self,
|
|
||||||
encoding: _serialization.Encoding,
|
|
||||||
format: _serialization.ParameterFormat,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Returns the parameters serialized as bytes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def parameter_numbers(self) -> DHParameterNumbers:
|
|
||||||
"""
|
|
||||||
Returns a DHParameterNumbers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
DHParametersWithSerialization = DHParameters
|
|
||||||
|
|
||||||
|
|
||||||
class DHPublicKey(metaclass=abc.ABCMeta):
|
|
||||||
@abc.abstractproperty
|
|
||||||
def key_size(self) -> int:
|
|
||||||
"""
|
|
||||||
The bit length of the prime modulus.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def parameters(self) -> DHParameters:
|
|
||||||
"""
|
|
||||||
The DHParameters object associated with this public key.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def public_numbers(self) -> DHPublicNumbers:
|
|
||||||
"""
|
|
||||||
Returns a DHPublicNumbers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def public_bytes(
|
|
||||||
self,
|
|
||||||
encoding: _serialization.Encoding,
|
|
||||||
format: _serialization.PublicFormat,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Returns the key serialized as bytes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
DHPublicKeyWithSerialization = DHPublicKey
|
|
||||||
|
|
||||||
|
|
||||||
class DHPrivateKey(metaclass=abc.ABCMeta):
|
|
||||||
@abc.abstractproperty
|
|
||||||
def key_size(self) -> int:
|
|
||||||
"""
|
|
||||||
The bit length of the prime modulus.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def public_key(self) -> DHPublicKey:
|
|
||||||
"""
|
|
||||||
The DHPublicKey associated with this private key.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def parameters(self) -> DHParameters:
|
|
||||||
"""
|
|
||||||
The DHParameters object associated with this private key.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def exchange(self, peer_public_key: DHPublicKey) -> bytes:
|
|
||||||
"""
|
|
||||||
Given peer's DHPublicKey, carry out the key exchange and
|
|
||||||
return shared key as bytes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def private_numbers(self) -> DHPrivateNumbers:
|
|
||||||
"""
|
|
||||||
Returns a DHPrivateNumbers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def private_bytes(
|
|
||||||
self,
|
|
||||||
encoding: _serialization.Encoding,
|
|
||||||
format: _serialization.PrivateFormat,
|
|
||||||
encryption_algorithm: _serialization.KeySerializationEncryption,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Returns the key serialized as bytes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
DHPrivateKeyWithSerialization = DHPrivateKey
|
|
||||||
@@ -1,288 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
import abc
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.hazmat.primitives import _serialization, hashes
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import (
|
|
||||||
utils as asym_utils,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DSAParameters(metaclass=abc.ABCMeta):
|
|
||||||
@abc.abstractmethod
|
|
||||||
def generate_private_key(self) -> "DSAPrivateKey":
|
|
||||||
"""
|
|
||||||
Generates and returns a DSAPrivateKey.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def parameter_numbers(self) -> "DSAParameterNumbers":
|
|
||||||
"""
|
|
||||||
Returns a DSAParameterNumbers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
DSAParametersWithNumbers = DSAParameters
|
|
||||||
|
|
||||||
|
|
||||||
class DSAPrivateKey(metaclass=abc.ABCMeta):
|
|
||||||
@abc.abstractproperty
|
|
||||||
def key_size(self) -> int:
|
|
||||||
"""
|
|
||||||
The bit length of the prime modulus.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def public_key(self) -> "DSAPublicKey":
|
|
||||||
"""
|
|
||||||
The DSAPublicKey associated with this private key.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def parameters(self) -> DSAParameters:
|
|
||||||
"""
|
|
||||||
The DSAParameters object associated with this private key.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def sign(
|
|
||||||
self,
|
|
||||||
data: bytes,
|
|
||||||
algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Signs the data
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def private_numbers(self) -> "DSAPrivateNumbers":
|
|
||||||
"""
|
|
||||||
Returns a DSAPrivateNumbers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def private_bytes(
|
|
||||||
self,
|
|
||||||
encoding: _serialization.Encoding,
|
|
||||||
format: _serialization.PrivateFormat,
|
|
||||||
encryption_algorithm: _serialization.KeySerializationEncryption,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Returns the key serialized as bytes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
DSAPrivateKeyWithSerialization = DSAPrivateKey
|
|
||||||
|
|
||||||
|
|
||||||
class DSAPublicKey(metaclass=abc.ABCMeta):
|
|
||||||
@abc.abstractproperty
|
|
||||||
def key_size(self) -> int:
|
|
||||||
"""
|
|
||||||
The bit length of the prime modulus.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def parameters(self) -> DSAParameters:
|
|
||||||
"""
|
|
||||||
The DSAParameters object associated with this public key.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def public_numbers(self) -> "DSAPublicNumbers":
|
|
||||||
"""
|
|
||||||
Returns a DSAPublicNumbers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def public_bytes(
|
|
||||||
self,
|
|
||||||
encoding: _serialization.Encoding,
|
|
||||||
format: _serialization.PublicFormat,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Returns the key serialized as bytes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def verify(
|
|
||||||
self,
|
|
||||||
signature: bytes,
|
|
||||||
data: bytes,
|
|
||||||
algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Verifies the signature of the data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
DSAPublicKeyWithSerialization = DSAPublicKey
|
|
||||||
|
|
||||||
|
|
||||||
class DSAParameterNumbers:
|
|
||||||
def __init__(self, p: int, q: int, g: int):
|
|
||||||
if (
|
|
||||||
not isinstance(p, int)
|
|
||||||
or not isinstance(q, int)
|
|
||||||
or not isinstance(g, int)
|
|
||||||
):
|
|
||||||
raise TypeError(
|
|
||||||
"DSAParameterNumbers p, q, and g arguments must be integers."
|
|
||||||
)
|
|
||||||
|
|
||||||
self._p = p
|
|
||||||
self._q = q
|
|
||||||
self._g = g
|
|
||||||
|
|
||||||
@property
|
|
||||||
def p(self) -> int:
|
|
||||||
return self._p
|
|
||||||
|
|
||||||
@property
|
|
||||||
def q(self) -> int:
|
|
||||||
return self._q
|
|
||||||
|
|
||||||
@property
|
|
||||||
def g(self) -> int:
|
|
||||||
return self._g
|
|
||||||
|
|
||||||
def parameters(self, backend: typing.Any = None) -> DSAParameters:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import (
|
|
||||||
backend as ossl,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ossl.load_dsa_parameter_numbers(self)
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, DSAParameterNumbers):
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
return self.p == other.p and self.q == other.q and self.g == other.g
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return (
|
|
||||||
"<DSAParameterNumbers(p={self.p}, q={self.q}, "
|
|
||||||
"g={self.g})>".format(self=self)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DSAPublicNumbers:
|
|
||||||
def __init__(self, y: int, parameter_numbers: DSAParameterNumbers):
|
|
||||||
if not isinstance(y, int):
|
|
||||||
raise TypeError("DSAPublicNumbers y argument must be an integer.")
|
|
||||||
|
|
||||||
if not isinstance(parameter_numbers, DSAParameterNumbers):
|
|
||||||
raise TypeError(
|
|
||||||
"parameter_numbers must be a DSAParameterNumbers instance."
|
|
||||||
)
|
|
||||||
|
|
||||||
self._y = y
|
|
||||||
self._parameter_numbers = parameter_numbers
|
|
||||||
|
|
||||||
@property
|
|
||||||
def y(self) -> int:
|
|
||||||
return self._y
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameter_numbers(self) -> DSAParameterNumbers:
|
|
||||||
return self._parameter_numbers
|
|
||||||
|
|
||||||
def public_key(self, backend: typing.Any = None) -> DSAPublicKey:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import (
|
|
||||||
backend as ossl,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ossl.load_dsa_public_numbers(self)
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, DSAPublicNumbers):
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
return (
|
|
||||||
self.y == other.y
|
|
||||||
and self.parameter_numbers == other.parameter_numbers
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return (
|
|
||||||
"<DSAPublicNumbers(y={self.y}, "
|
|
||||||
"parameter_numbers={self.parameter_numbers})>".format(self=self)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DSAPrivateNumbers:
|
|
||||||
def __init__(self, x: int, public_numbers: DSAPublicNumbers):
|
|
||||||
if not isinstance(x, int):
|
|
||||||
raise TypeError("DSAPrivateNumbers x argument must be an integer.")
|
|
||||||
|
|
||||||
if not isinstance(public_numbers, DSAPublicNumbers):
|
|
||||||
raise TypeError(
|
|
||||||
"public_numbers must be a DSAPublicNumbers instance."
|
|
||||||
)
|
|
||||||
self._public_numbers = public_numbers
|
|
||||||
self._x = x
|
|
||||||
|
|
||||||
@property
|
|
||||||
def x(self) -> int:
|
|
||||||
return self._x
|
|
||||||
|
|
||||||
@property
|
|
||||||
def public_numbers(self) -> DSAPublicNumbers:
|
|
||||||
return self._public_numbers
|
|
||||||
|
|
||||||
def private_key(self, backend: typing.Any = None) -> DSAPrivateKey:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import (
|
|
||||||
backend as ossl,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ossl.load_dsa_private_numbers(self)
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, DSAPrivateNumbers):
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
return (
|
|
||||||
self.x == other.x and self.public_numbers == other.public_numbers
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_parameters(
|
|
||||||
key_size: int, backend: typing.Any = None
|
|
||||||
) -> DSAParameters:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import backend as ossl
|
|
||||||
|
|
||||||
return ossl.generate_dsa_parameters(key_size)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_private_key(
|
|
||||||
key_size: int, backend: typing.Any = None
|
|
||||||
) -> DSAPrivateKey:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import backend as ossl
|
|
||||||
|
|
||||||
return ossl.generate_dsa_private_key_and_parameters(key_size)
|
|
||||||
|
|
||||||
|
|
||||||
def _check_dsa_parameters(parameters: DSAParameterNumbers) -> None:
|
|
||||||
if parameters.p.bit_length() not in [1024, 2048, 3072, 4096]:
|
|
||||||
raise ValueError(
|
|
||||||
"p must be exactly 1024, 2048, 3072, or 4096 bits long"
|
|
||||||
)
|
|
||||||
if parameters.q.bit_length() not in [160, 224, 256]:
|
|
||||||
raise ValueError("q must be exactly 160, 224, or 256 bits long")
|
|
||||||
|
|
||||||
if not (1 < parameters.g < parameters.p):
|
|
||||||
raise ValueError("g, p don't satisfy 1 < g < p.")
|
|
||||||
|
|
||||||
|
|
||||||
def _check_dsa_private_numbers(numbers: DSAPrivateNumbers) -> None:
|
|
||||||
parameters = numbers.public_numbers.parameter_numbers
|
|
||||||
_check_dsa_parameters(parameters)
|
|
||||||
if numbers.x <= 0 or numbers.x >= parameters.q:
|
|
||||||
raise ValueError("x must be > 0 and < q.")
|
|
||||||
|
|
||||||
if numbers.public_numbers.y != pow(parameters.g, numbers.x, parameters.p):
|
|
||||||
raise ValueError("y must be equal to (g ** x % p).")
|
|
||||||
@@ -1,523 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
import abc
|
|
||||||
import typing
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
from cryptography import utils
|
|
||||||
from cryptography.hazmat._oid import ObjectIdentifier
|
|
||||||
from cryptography.hazmat.primitives import _serialization, hashes
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import (
|
|
||||||
utils as asym_utils,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class EllipticCurveOID:
|
|
||||||
SECP192R1 = ObjectIdentifier("1.2.840.10045.3.1.1")
|
|
||||||
SECP224R1 = ObjectIdentifier("1.3.132.0.33")
|
|
||||||
SECP256K1 = ObjectIdentifier("1.3.132.0.10")
|
|
||||||
SECP256R1 = ObjectIdentifier("1.2.840.10045.3.1.7")
|
|
||||||
SECP384R1 = ObjectIdentifier("1.3.132.0.34")
|
|
||||||
SECP521R1 = ObjectIdentifier("1.3.132.0.35")
|
|
||||||
BRAINPOOLP256R1 = ObjectIdentifier("1.3.36.3.3.2.8.1.1.7")
|
|
||||||
BRAINPOOLP384R1 = ObjectIdentifier("1.3.36.3.3.2.8.1.1.11")
|
|
||||||
BRAINPOOLP512R1 = ObjectIdentifier("1.3.36.3.3.2.8.1.1.13")
|
|
||||||
SECT163K1 = ObjectIdentifier("1.3.132.0.1")
|
|
||||||
SECT163R2 = ObjectIdentifier("1.3.132.0.15")
|
|
||||||
SECT233K1 = ObjectIdentifier("1.3.132.0.26")
|
|
||||||
SECT233R1 = ObjectIdentifier("1.3.132.0.27")
|
|
||||||
SECT283K1 = ObjectIdentifier("1.3.132.0.16")
|
|
||||||
SECT283R1 = ObjectIdentifier("1.3.132.0.17")
|
|
||||||
SECT409K1 = ObjectIdentifier("1.3.132.0.36")
|
|
||||||
SECT409R1 = ObjectIdentifier("1.3.132.0.37")
|
|
||||||
SECT571K1 = ObjectIdentifier("1.3.132.0.38")
|
|
||||||
SECT571R1 = ObjectIdentifier("1.3.132.0.39")
|
|
||||||
|
|
||||||
|
|
||||||
class EllipticCurve(metaclass=abc.ABCMeta):
|
|
||||||
@abc.abstractproperty
|
|
||||||
def name(self) -> str:
|
|
||||||
"""
|
|
||||||
The name of the curve. e.g. secp256r1.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractproperty
|
|
||||||
def key_size(self) -> int:
|
|
||||||
"""
|
|
||||||
Bit size of a secret scalar for the curve.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class EllipticCurveSignatureAlgorithm(metaclass=abc.ABCMeta):
|
|
||||||
@abc.abstractproperty
|
|
||||||
def algorithm(
|
|
||||||
self,
|
|
||||||
) -> typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm]:
|
|
||||||
"""
|
|
||||||
The digest algorithm used with this signature.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class EllipticCurvePrivateKey(metaclass=abc.ABCMeta):
|
|
||||||
@abc.abstractmethod
|
|
||||||
def exchange(
|
|
||||||
self, algorithm: "ECDH", peer_public_key: "EllipticCurvePublicKey"
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Performs a key exchange operation using the provided algorithm with the
|
|
||||||
provided peer's public key.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def public_key(self) -> "EllipticCurvePublicKey":
|
|
||||||
"""
|
|
||||||
The EllipticCurvePublicKey for this private key.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractproperty
|
|
||||||
def curve(self) -> EllipticCurve:
|
|
||||||
"""
|
|
||||||
The EllipticCurve that this key is on.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractproperty
|
|
||||||
def key_size(self) -> int:
|
|
||||||
"""
|
|
||||||
Bit size of a secret scalar for the curve.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def sign(
|
|
||||||
self,
|
|
||||||
data: bytes,
|
|
||||||
signature_algorithm: EllipticCurveSignatureAlgorithm,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Signs the data
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def private_numbers(self) -> "EllipticCurvePrivateNumbers":
|
|
||||||
"""
|
|
||||||
Returns an EllipticCurvePrivateNumbers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def private_bytes(
|
|
||||||
self,
|
|
||||||
encoding: _serialization.Encoding,
|
|
||||||
format: _serialization.PrivateFormat,
|
|
||||||
encryption_algorithm: _serialization.KeySerializationEncryption,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Returns the key serialized as bytes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
EllipticCurvePrivateKeyWithSerialization = EllipticCurvePrivateKey
|
|
||||||
|
|
||||||
|
|
||||||
class EllipticCurvePublicKey(metaclass=abc.ABCMeta):
|
|
||||||
@abc.abstractproperty
|
|
||||||
def curve(self) -> EllipticCurve:
|
|
||||||
"""
|
|
||||||
The EllipticCurve that this key is on.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractproperty
|
|
||||||
def key_size(self) -> int:
|
|
||||||
"""
|
|
||||||
Bit size of a secret scalar for the curve.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def public_numbers(self) -> "EllipticCurvePublicNumbers":
|
|
||||||
"""
|
|
||||||
Returns an EllipticCurvePublicNumbers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def public_bytes(
|
|
||||||
self,
|
|
||||||
encoding: _serialization.Encoding,
|
|
||||||
format: _serialization.PublicFormat,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Returns the key serialized as bytes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def verify(
|
|
||||||
self,
|
|
||||||
signature: bytes,
|
|
||||||
data: bytes,
|
|
||||||
signature_algorithm: EllipticCurveSignatureAlgorithm,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Verifies the signature of the data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_encoded_point(
|
|
||||||
cls, curve: EllipticCurve, data: bytes
|
|
||||||
) -> "EllipticCurvePublicKey":
|
|
||||||
utils._check_bytes("data", data)
|
|
||||||
|
|
||||||
if not isinstance(curve, EllipticCurve):
|
|
||||||
raise TypeError("curve must be an EllipticCurve instance")
|
|
||||||
|
|
||||||
if len(data) == 0:
|
|
||||||
raise ValueError("data must not be an empty byte string")
|
|
||||||
|
|
||||||
if data[0] not in [0x02, 0x03, 0x04]:
|
|
||||||
raise ValueError("Unsupported elliptic curve point type")
|
|
||||||
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import backend
|
|
||||||
|
|
||||||
return backend.load_elliptic_curve_public_bytes(curve, data)
|
|
||||||
|
|
||||||
|
|
||||||
EllipticCurvePublicKeyWithSerialization = EllipticCurvePublicKey
|
|
||||||
|
|
||||||
|
|
||||||
class SECT571R1(EllipticCurve):
|
|
||||||
name = "sect571r1"
|
|
||||||
key_size = 570
|
|
||||||
|
|
||||||
|
|
||||||
class SECT409R1(EllipticCurve):
|
|
||||||
name = "sect409r1"
|
|
||||||
key_size = 409
|
|
||||||
|
|
||||||
|
|
||||||
class SECT283R1(EllipticCurve):
|
|
||||||
name = "sect283r1"
|
|
||||||
key_size = 283
|
|
||||||
|
|
||||||
|
|
||||||
class SECT233R1(EllipticCurve):
|
|
||||||
name = "sect233r1"
|
|
||||||
key_size = 233
|
|
||||||
|
|
||||||
|
|
||||||
class SECT163R2(EllipticCurve):
|
|
||||||
name = "sect163r2"
|
|
||||||
key_size = 163
|
|
||||||
|
|
||||||
|
|
||||||
class SECT571K1(EllipticCurve):
|
|
||||||
name = "sect571k1"
|
|
||||||
key_size = 571
|
|
||||||
|
|
||||||
|
|
||||||
class SECT409K1(EllipticCurve):
|
|
||||||
name = "sect409k1"
|
|
||||||
key_size = 409
|
|
||||||
|
|
||||||
|
|
||||||
class SECT283K1(EllipticCurve):
|
|
||||||
name = "sect283k1"
|
|
||||||
key_size = 283
|
|
||||||
|
|
||||||
|
|
||||||
class SECT233K1(EllipticCurve):
|
|
||||||
name = "sect233k1"
|
|
||||||
key_size = 233
|
|
||||||
|
|
||||||
|
|
||||||
class SECT163K1(EllipticCurve):
|
|
||||||
name = "sect163k1"
|
|
||||||
key_size = 163
|
|
||||||
|
|
||||||
|
|
||||||
class SECP521R1(EllipticCurve):
|
|
||||||
name = "secp521r1"
|
|
||||||
key_size = 521
|
|
||||||
|
|
||||||
|
|
||||||
class SECP384R1(EllipticCurve):
|
|
||||||
name = "secp384r1"
|
|
||||||
key_size = 384
|
|
||||||
|
|
||||||
|
|
||||||
class SECP256R1(EllipticCurve):
|
|
||||||
name = "secp256r1"
|
|
||||||
key_size = 256
|
|
||||||
|
|
||||||
|
|
||||||
class SECP256K1(EllipticCurve):
|
|
||||||
name = "secp256k1"
|
|
||||||
key_size = 256
|
|
||||||
|
|
||||||
|
|
||||||
class SECP224R1(EllipticCurve):
|
|
||||||
name = "secp224r1"
|
|
||||||
key_size = 224
|
|
||||||
|
|
||||||
|
|
||||||
class SECP192R1(EllipticCurve):
|
|
||||||
name = "secp192r1"
|
|
||||||
key_size = 192
|
|
||||||
|
|
||||||
|
|
||||||
class BrainpoolP256R1(EllipticCurve):
|
|
||||||
name = "brainpoolP256r1"
|
|
||||||
key_size = 256
|
|
||||||
|
|
||||||
|
|
||||||
class BrainpoolP384R1(EllipticCurve):
|
|
||||||
name = "brainpoolP384r1"
|
|
||||||
key_size = 384
|
|
||||||
|
|
||||||
|
|
||||||
class BrainpoolP512R1(EllipticCurve):
|
|
||||||
name = "brainpoolP512r1"
|
|
||||||
key_size = 512
|
|
||||||
|
|
||||||
|
|
||||||
_CURVE_TYPES: typing.Dict[str, typing.Type[EllipticCurve]] = {
|
|
||||||
"prime192v1": SECP192R1,
|
|
||||||
"prime256v1": SECP256R1,
|
|
||||||
"secp192r1": SECP192R1,
|
|
||||||
"secp224r1": SECP224R1,
|
|
||||||
"secp256r1": SECP256R1,
|
|
||||||
"secp384r1": SECP384R1,
|
|
||||||
"secp521r1": SECP521R1,
|
|
||||||
"secp256k1": SECP256K1,
|
|
||||||
"sect163k1": SECT163K1,
|
|
||||||
"sect233k1": SECT233K1,
|
|
||||||
"sect283k1": SECT283K1,
|
|
||||||
"sect409k1": SECT409K1,
|
|
||||||
"sect571k1": SECT571K1,
|
|
||||||
"sect163r2": SECT163R2,
|
|
||||||
"sect233r1": SECT233R1,
|
|
||||||
"sect283r1": SECT283R1,
|
|
||||||
"sect409r1": SECT409R1,
|
|
||||||
"sect571r1": SECT571R1,
|
|
||||||
"brainpoolP256r1": BrainpoolP256R1,
|
|
||||||
"brainpoolP384r1": BrainpoolP384R1,
|
|
||||||
"brainpoolP512r1": BrainpoolP512R1,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ECDSA(EllipticCurveSignatureAlgorithm):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
|
|
||||||
):
|
|
||||||
self._algorithm = algorithm
|
|
||||||
|
|
||||||
@property
|
|
||||||
def algorithm(
|
|
||||||
self,
|
|
||||||
) -> typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm]:
|
|
||||||
return self._algorithm
|
|
||||||
|
|
||||||
|
|
||||||
def generate_private_key(
|
|
||||||
curve: EllipticCurve, backend: typing.Any = None
|
|
||||||
) -> EllipticCurvePrivateKey:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import backend as ossl
|
|
||||||
|
|
||||||
return ossl.generate_elliptic_curve_private_key(curve)
|
|
||||||
|
|
||||||
|
|
||||||
def derive_private_key(
|
|
||||||
private_value: int,
|
|
||||||
curve: EllipticCurve,
|
|
||||||
backend: typing.Any = None,
|
|
||||||
) -> EllipticCurvePrivateKey:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import backend as ossl
|
|
||||||
|
|
||||||
if not isinstance(private_value, int):
|
|
||||||
raise TypeError("private_value must be an integer type.")
|
|
||||||
|
|
||||||
if private_value <= 0:
|
|
||||||
raise ValueError("private_value must be a positive integer.")
|
|
||||||
|
|
||||||
if not isinstance(curve, EllipticCurve):
|
|
||||||
raise TypeError("curve must provide the EllipticCurve interface.")
|
|
||||||
|
|
||||||
return ossl.derive_elliptic_curve_private_key(private_value, curve)
|
|
||||||
|
|
||||||
|
|
||||||
class EllipticCurvePublicNumbers:
|
|
||||||
def __init__(self, x: int, y: int, curve: EllipticCurve):
|
|
||||||
if not isinstance(x, int) or not isinstance(y, int):
|
|
||||||
raise TypeError("x and y must be integers.")
|
|
||||||
|
|
||||||
if not isinstance(curve, EllipticCurve):
|
|
||||||
raise TypeError("curve must provide the EllipticCurve interface.")
|
|
||||||
|
|
||||||
self._y = y
|
|
||||||
self._x = x
|
|
||||||
self._curve = curve
|
|
||||||
|
|
||||||
def public_key(self, backend: typing.Any = None) -> EllipticCurvePublicKey:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import (
|
|
||||||
backend as ossl,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ossl.load_elliptic_curve_public_numbers(self)
|
|
||||||
|
|
||||||
def encode_point(self) -> bytes:
|
|
||||||
warnings.warn(
|
|
||||||
"encode_point has been deprecated on EllipticCurvePublicNumbers"
|
|
||||||
" and will be removed in a future version. Please use "
|
|
||||||
"EllipticCurvePublicKey.public_bytes to obtain both "
|
|
||||||
"compressed and uncompressed point encoding.",
|
|
||||||
utils.PersistentlyDeprecated2019,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
# key_size is in bits. Convert to bytes and round up
|
|
||||||
byte_length = (self.curve.key_size + 7) // 8
|
|
||||||
return (
|
|
||||||
b"\x04"
|
|
||||||
+ utils.int_to_bytes(self.x, byte_length)
|
|
||||||
+ utils.int_to_bytes(self.y, byte_length)
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_encoded_point(
|
|
||||||
cls, curve: EllipticCurve, data: bytes
|
|
||||||
) -> "EllipticCurvePublicNumbers":
|
|
||||||
if not isinstance(curve, EllipticCurve):
|
|
||||||
raise TypeError("curve must be an EllipticCurve instance")
|
|
||||||
|
|
||||||
warnings.warn(
|
|
||||||
"Support for unsafe construction of public numbers from "
|
|
||||||
"encoded data will be removed in a future version. "
|
|
||||||
"Please use EllipticCurvePublicKey.from_encoded_point",
|
|
||||||
utils.PersistentlyDeprecated2019,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
if data.startswith(b"\x04"):
|
|
||||||
# key_size is in bits. Convert to bytes and round up
|
|
||||||
byte_length = (curve.key_size + 7) // 8
|
|
||||||
if len(data) == 2 * byte_length + 1:
|
|
||||||
x = int.from_bytes(data[1 : byte_length + 1], "big")
|
|
||||||
y = int.from_bytes(data[byte_length + 1 :], "big")
|
|
||||||
return cls(x, y, curve)
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid elliptic curve point data length")
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported elliptic curve point type")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def curve(self) -> EllipticCurve:
|
|
||||||
return self._curve
|
|
||||||
|
|
||||||
@property
|
|
||||||
def x(self) -> int:
|
|
||||||
return self._x
|
|
||||||
|
|
||||||
@property
|
|
||||||
def y(self) -> int:
|
|
||||||
return self._y
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, EllipticCurvePublicNumbers):
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
return (
|
|
||||||
self.x == other.x
|
|
||||||
and self.y == other.y
|
|
||||||
and self.curve.name == other.curve.name
|
|
||||||
and self.curve.key_size == other.curve.key_size
|
|
||||||
)
|
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
|
||||||
return hash((self.x, self.y, self.curve.name, self.curve.key_size))
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return (
|
|
||||||
"<EllipticCurvePublicNumbers(curve={0.curve.name}, x={0.x}, "
|
|
||||||
"y={0.y}>".format(self)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class EllipticCurvePrivateNumbers:
|
|
||||||
def __init__(
|
|
||||||
self, private_value: int, public_numbers: EllipticCurvePublicNumbers
|
|
||||||
):
|
|
||||||
if not isinstance(private_value, int):
|
|
||||||
raise TypeError("private_value must be an integer.")
|
|
||||||
|
|
||||||
if not isinstance(public_numbers, EllipticCurvePublicNumbers):
|
|
||||||
raise TypeError(
|
|
||||||
"public_numbers must be an EllipticCurvePublicNumbers "
|
|
||||||
"instance."
|
|
||||||
)
|
|
||||||
|
|
||||||
self._private_value = private_value
|
|
||||||
self._public_numbers = public_numbers
|
|
||||||
|
|
||||||
def private_key(
|
|
||||||
self, backend: typing.Any = None
|
|
||||||
) -> EllipticCurvePrivateKey:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import (
|
|
||||||
backend as ossl,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ossl.load_elliptic_curve_private_numbers(self)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def private_value(self) -> int:
|
|
||||||
return self._private_value
|
|
||||||
|
|
||||||
@property
|
|
||||||
def public_numbers(self) -> EllipticCurvePublicNumbers:
|
|
||||||
return self._public_numbers
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, EllipticCurvePrivateNumbers):
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
return (
|
|
||||||
self.private_value == other.private_value
|
|
||||||
and self.public_numbers == other.public_numbers
|
|
||||||
)
|
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
|
||||||
return hash((self.private_value, self.public_numbers))
|
|
||||||
|
|
||||||
|
|
||||||
class ECDH:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
_OID_TO_CURVE = {
|
|
||||||
EllipticCurveOID.SECP192R1: SECP192R1,
|
|
||||||
EllipticCurveOID.SECP224R1: SECP224R1,
|
|
||||||
EllipticCurveOID.SECP256K1: SECP256K1,
|
|
||||||
EllipticCurveOID.SECP256R1: SECP256R1,
|
|
||||||
EllipticCurveOID.SECP384R1: SECP384R1,
|
|
||||||
EllipticCurveOID.SECP521R1: SECP521R1,
|
|
||||||
EllipticCurveOID.BRAINPOOLP256R1: BrainpoolP256R1,
|
|
||||||
EllipticCurveOID.BRAINPOOLP384R1: BrainpoolP384R1,
|
|
||||||
EllipticCurveOID.BRAINPOOLP512R1: BrainpoolP512R1,
|
|
||||||
EllipticCurveOID.SECT163K1: SECT163K1,
|
|
||||||
EllipticCurveOID.SECT163R2: SECT163R2,
|
|
||||||
EllipticCurveOID.SECT233K1: SECT233K1,
|
|
||||||
EllipticCurveOID.SECT233R1: SECT233R1,
|
|
||||||
EllipticCurveOID.SECT283K1: SECT283K1,
|
|
||||||
EllipticCurveOID.SECT283R1: SECT283R1,
|
|
||||||
EllipticCurveOID.SECT409K1: SECT409K1,
|
|
||||||
EllipticCurveOID.SECT409R1: SECT409R1,
|
|
||||||
EllipticCurveOID.SECT571K1: SECT571K1,
|
|
||||||
EllipticCurveOID.SECT571R1: SECT571R1,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_curve_for_oid(oid: ObjectIdentifier) -> typing.Type[EllipticCurve]:
|
|
||||||
try:
|
|
||||||
return _OID_TO_CURVE[oid]
|
|
||||||
except KeyError:
|
|
||||||
raise LookupError(
|
|
||||||
"The provided object identifier has no matching elliptic "
|
|
||||||
"curve class"
|
|
||||||
)
|
|
||||||
@@ -1,92 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
import abc
|
|
||||||
|
|
||||||
from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
|
|
||||||
from cryptography.hazmat.primitives import _serialization
|
|
||||||
|
|
||||||
|
|
||||||
_ED25519_KEY_SIZE = 32
|
|
||||||
_ED25519_SIG_SIZE = 64
|
|
||||||
|
|
||||||
|
|
||||||
class Ed25519PublicKey(metaclass=abc.ABCMeta):
|
|
||||||
@classmethod
|
|
||||||
def from_public_bytes(cls, data: bytes) -> "Ed25519PublicKey":
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import backend
|
|
||||||
|
|
||||||
if not backend.ed25519_supported():
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"ed25519 is not supported by this version of OpenSSL.",
|
|
||||||
_Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM,
|
|
||||||
)
|
|
||||||
|
|
||||||
return backend.ed25519_load_public_bytes(data)
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def public_bytes(
|
|
||||||
self,
|
|
||||||
encoding: _serialization.Encoding,
|
|
||||||
format: _serialization.PublicFormat,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
The serialized bytes of the public key.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def verify(self, signature: bytes, data: bytes) -> None:
|
|
||||||
"""
|
|
||||||
Verify the signature.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class Ed25519PrivateKey(metaclass=abc.ABCMeta):
|
|
||||||
@classmethod
|
|
||||||
def generate(cls) -> "Ed25519PrivateKey":
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import backend
|
|
||||||
|
|
||||||
if not backend.ed25519_supported():
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"ed25519 is not supported by this version of OpenSSL.",
|
|
||||||
_Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM,
|
|
||||||
)
|
|
||||||
|
|
||||||
return backend.ed25519_generate_key()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_private_bytes(cls, data: bytes) -> "Ed25519PrivateKey":
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import backend
|
|
||||||
|
|
||||||
if not backend.ed25519_supported():
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"ed25519 is not supported by this version of OpenSSL.",
|
|
||||||
_Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM,
|
|
||||||
)
|
|
||||||
|
|
||||||
return backend.ed25519_load_private_bytes(data)
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def public_key(self) -> Ed25519PublicKey:
|
|
||||||
"""
|
|
||||||
The Ed25519PublicKey derived from the private key.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def private_bytes(
|
|
||||||
self,
|
|
||||||
encoding: _serialization.Encoding,
|
|
||||||
format: _serialization.PrivateFormat,
|
|
||||||
encryption_algorithm: _serialization.KeySerializationEncryption,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
The serialized bytes of the private key.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def sign(self, data: bytes) -> bytes:
|
|
||||||
"""
|
|
||||||
Signs the data.
|
|
||||||
"""
|
|
||||||
@@ -1,87 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
import abc
|
|
||||||
|
|
||||||
from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
|
|
||||||
from cryptography.hazmat.primitives import _serialization
|
|
||||||
|
|
||||||
|
|
||||||
class Ed448PublicKey(metaclass=abc.ABCMeta):
|
|
||||||
@classmethod
|
|
||||||
def from_public_bytes(cls, data: bytes) -> "Ed448PublicKey":
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import backend
|
|
||||||
|
|
||||||
if not backend.ed448_supported():
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"ed448 is not supported by this version of OpenSSL.",
|
|
||||||
_Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM,
|
|
||||||
)
|
|
||||||
|
|
||||||
return backend.ed448_load_public_bytes(data)
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def public_bytes(
|
|
||||||
self,
|
|
||||||
encoding: _serialization.Encoding,
|
|
||||||
format: _serialization.PublicFormat,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
The serialized bytes of the public key.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def verify(self, signature: bytes, data: bytes) -> None:
|
|
||||||
"""
|
|
||||||
Verify the signature.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class Ed448PrivateKey(metaclass=abc.ABCMeta):
|
|
||||||
@classmethod
|
|
||||||
def generate(cls) -> "Ed448PrivateKey":
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import backend
|
|
||||||
|
|
||||||
if not backend.ed448_supported():
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"ed448 is not supported by this version of OpenSSL.",
|
|
||||||
_Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM,
|
|
||||||
)
|
|
||||||
return backend.ed448_generate_key()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_private_bytes(cls, data: bytes) -> "Ed448PrivateKey":
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import backend
|
|
||||||
|
|
||||||
if not backend.ed448_supported():
|
|
||||||
raise UnsupportedAlgorithm(
|
|
||||||
"ed448 is not supported by this version of OpenSSL.",
|
|
||||||
_Reasons.UNSUPPORTED_PUBLIC_KEY_ALGORITHM,
|
|
||||||
)
|
|
||||||
|
|
||||||
return backend.ed448_load_private_bytes(data)
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def public_key(self) -> Ed448PublicKey:
|
|
||||||
"""
|
|
||||||
The Ed448PublicKey derived from the private key.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def sign(self, data: bytes) -> bytes:
|
|
||||||
"""
|
|
||||||
Signs the data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def private_bytes(
|
|
||||||
self,
|
|
||||||
encoding: _serialization.Encoding,
|
|
||||||
format: _serialization.PrivateFormat,
|
|
||||||
encryption_algorithm: _serialization.KeySerializationEncryption,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
The serialized bytes of the private key.
|
|
||||||
"""
|
|
||||||
@@ -1,101 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
import abc
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.hazmat.primitives import hashes
|
|
||||||
from cryptography.hazmat.primitives._asymmetric import (
|
|
||||||
AsymmetricPadding as AsymmetricPadding,
|
|
||||||
)
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
||||||
|
|
||||||
|
|
||||||
class PKCS1v15(AsymmetricPadding):
|
|
||||||
name = "EMSA-PKCS1-v1_5"
|
|
||||||
|
|
||||||
|
|
||||||
class _MaxLength:
|
|
||||||
"Sentinel value for `MAX_LENGTH`."
|
|
||||||
|
|
||||||
|
|
||||||
class _Auto:
|
|
||||||
"Sentinel value for `AUTO`."
|
|
||||||
|
|
||||||
|
|
||||||
class _DigestLength:
|
|
||||||
"Sentinel value for `DIGEST_LENGTH`."
|
|
||||||
|
|
||||||
|
|
||||||
class PSS(AsymmetricPadding):
|
|
||||||
MAX_LENGTH = _MaxLength()
|
|
||||||
AUTO = _Auto()
|
|
||||||
DIGEST_LENGTH = _DigestLength()
|
|
||||||
name = "EMSA-PSS"
|
|
||||||
_salt_length: typing.Union[int, _MaxLength, _Auto, _DigestLength]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
mgf: "MGF",
|
|
||||||
salt_length: typing.Union[int, _MaxLength, _Auto, _DigestLength],
|
|
||||||
) -> None:
|
|
||||||
self._mgf = mgf
|
|
||||||
|
|
||||||
if not isinstance(
|
|
||||||
salt_length, (int, _MaxLength, _Auto, _DigestLength)
|
|
||||||
):
|
|
||||||
raise TypeError(
|
|
||||||
"salt_length must be an integer, MAX_LENGTH, "
|
|
||||||
"DIGEST_LENGTH, or AUTO"
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(salt_length, int) and salt_length < 0:
|
|
||||||
raise ValueError("salt_length must be zero or greater.")
|
|
||||||
|
|
||||||
self._salt_length = salt_length
|
|
||||||
|
|
||||||
|
|
||||||
class OAEP(AsymmetricPadding):
|
|
||||||
name = "EME-OAEP"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
mgf: "MGF",
|
|
||||||
algorithm: hashes.HashAlgorithm,
|
|
||||||
label: typing.Optional[bytes],
|
|
||||||
):
|
|
||||||
if not isinstance(algorithm, hashes.HashAlgorithm):
|
|
||||||
raise TypeError("Expected instance of hashes.HashAlgorithm.")
|
|
||||||
|
|
||||||
self._mgf = mgf
|
|
||||||
self._algorithm = algorithm
|
|
||||||
self._label = label
|
|
||||||
|
|
||||||
|
|
||||||
class MGF(metaclass=abc.ABCMeta):
|
|
||||||
_algorithm: hashes.HashAlgorithm
|
|
||||||
|
|
||||||
|
|
||||||
class MGF1(MGF):
|
|
||||||
MAX_LENGTH = _MaxLength()
|
|
||||||
|
|
||||||
def __init__(self, algorithm: hashes.HashAlgorithm):
|
|
||||||
if not isinstance(algorithm, hashes.HashAlgorithm):
|
|
||||||
raise TypeError("Expected instance of hashes.HashAlgorithm.")
|
|
||||||
|
|
||||||
self._algorithm = algorithm
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_max_pss_salt_length(
|
|
||||||
key: typing.Union["rsa.RSAPrivateKey", "rsa.RSAPublicKey"],
|
|
||||||
hash_algorithm: hashes.HashAlgorithm,
|
|
||||||
) -> int:
|
|
||||||
if not isinstance(key, (rsa.RSAPrivateKey, rsa.RSAPublicKey)):
|
|
||||||
raise TypeError("key must be an RSA public or private key")
|
|
||||||
# bit length - 1 per RFC 3447
|
|
||||||
emlen = (key.key_size + 6) // 8
|
|
||||||
salt_length = emlen - hash_algorithm.digest_size - 2
|
|
||||||
assert salt_length >= 0
|
|
||||||
return salt_length
|
|
||||||
@@ -1,425 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
|
|
||||||
import abc
|
|
||||||
import typing
|
|
||||||
from math import gcd
|
|
||||||
|
|
||||||
from cryptography.hazmat.primitives import _serialization, hashes
|
|
||||||
from cryptography.hazmat.primitives._asymmetric import AsymmetricPadding
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import (
|
|
||||||
utils as asym_utils,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RSAPrivateKey(metaclass=abc.ABCMeta):
|
|
||||||
@abc.abstractmethod
|
|
||||||
def decrypt(self, ciphertext: bytes, padding: AsymmetricPadding) -> bytes:
|
|
||||||
"""
|
|
||||||
Decrypts the provided ciphertext.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractproperty
|
|
||||||
def key_size(self) -> int:
|
|
||||||
"""
|
|
||||||
The bit length of the public modulus.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def public_key(self) -> "RSAPublicKey":
|
|
||||||
"""
|
|
||||||
The RSAPublicKey associated with this private key.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def sign(
|
|
||||||
self,
|
|
||||||
data: bytes,
|
|
||||||
padding: AsymmetricPadding,
|
|
||||||
algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Signs the data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def private_numbers(self) -> "RSAPrivateNumbers":
|
|
||||||
"""
|
|
||||||
Returns an RSAPrivateNumbers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def private_bytes(
|
|
||||||
self,
|
|
||||||
encoding: _serialization.Encoding,
|
|
||||||
format: _serialization.PrivateFormat,
|
|
||||||
encryption_algorithm: _serialization.KeySerializationEncryption,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Returns the key serialized as bytes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
RSAPrivateKeyWithSerialization = RSAPrivateKey
|
|
||||||
|
|
||||||
|
|
||||||
class RSAPublicKey(metaclass=abc.ABCMeta):
|
|
||||||
@abc.abstractmethod
|
|
||||||
def encrypt(self, plaintext: bytes, padding: AsymmetricPadding) -> bytes:
|
|
||||||
"""
|
|
||||||
Encrypts the given plaintext.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractproperty
|
|
||||||
def key_size(self) -> int:
|
|
||||||
"""
|
|
||||||
The bit length of the public modulus.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def public_numbers(self) -> "RSAPublicNumbers":
|
|
||||||
"""
|
|
||||||
Returns an RSAPublicNumbers
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def public_bytes(
|
|
||||||
self,
|
|
||||||
encoding: _serialization.Encoding,
|
|
||||||
format: _serialization.PublicFormat,
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Returns the key serialized as bytes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def verify(
|
|
||||||
self,
|
|
||||||
signature: bytes,
|
|
||||||
data: bytes,
|
|
||||||
padding: AsymmetricPadding,
|
|
||||||
algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Verifies the signature of the data.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def recover_data_from_signature(
|
|
||||||
self,
|
|
||||||
signature: bytes,
|
|
||||||
padding: AsymmetricPadding,
|
|
||||||
algorithm: typing.Optional[hashes.HashAlgorithm],
|
|
||||||
) -> bytes:
|
|
||||||
"""
|
|
||||||
Recovers the original data from the signature.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
RSAPublicKeyWithSerialization = RSAPublicKey
|
|
||||||
|
|
||||||
|
|
||||||
def generate_private_key(
|
|
||||||
public_exponent: int,
|
|
||||||
key_size: int,
|
|
||||||
backend: typing.Any = None,
|
|
||||||
) -> RSAPrivateKey:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import backend as ossl
|
|
||||||
|
|
||||||
_verify_rsa_parameters(public_exponent, key_size)
|
|
||||||
return ossl.generate_rsa_private_key(public_exponent, key_size)
|
|
||||||
|
|
||||||
|
|
||||||
def _verify_rsa_parameters(public_exponent: int, key_size: int) -> None:
|
|
||||||
if public_exponent not in (3, 65537):
|
|
||||||
raise ValueError(
|
|
||||||
"public_exponent must be either 3 (for legacy compatibility) or "
|
|
||||||
"65537. Almost everyone should choose 65537 here!"
|
|
||||||
)
|
|
||||||
|
|
||||||
if key_size < 512:
|
|
||||||
raise ValueError("key_size must be at least 512-bits.")
|
|
||||||
|
|
||||||
|
|
||||||
def _check_private_key_components(
|
|
||||||
p: int,
|
|
||||||
q: int,
|
|
||||||
private_exponent: int,
|
|
||||||
dmp1: int,
|
|
||||||
dmq1: int,
|
|
||||||
iqmp: int,
|
|
||||||
public_exponent: int,
|
|
||||||
modulus: int,
|
|
||||||
) -> None:
|
|
||||||
if modulus < 3:
|
|
||||||
raise ValueError("modulus must be >= 3.")
|
|
||||||
|
|
||||||
if p >= modulus:
|
|
||||||
raise ValueError("p must be < modulus.")
|
|
||||||
|
|
||||||
if q >= modulus:
|
|
||||||
raise ValueError("q must be < modulus.")
|
|
||||||
|
|
||||||
if dmp1 >= modulus:
|
|
||||||
raise ValueError("dmp1 must be < modulus.")
|
|
||||||
|
|
||||||
if dmq1 >= modulus:
|
|
||||||
raise ValueError("dmq1 must be < modulus.")
|
|
||||||
|
|
||||||
if iqmp >= modulus:
|
|
||||||
raise ValueError("iqmp must be < modulus.")
|
|
||||||
|
|
||||||
if private_exponent >= modulus:
|
|
||||||
raise ValueError("private_exponent must be < modulus.")
|
|
||||||
|
|
||||||
if public_exponent < 3 or public_exponent >= modulus:
|
|
||||||
raise ValueError("public_exponent must be >= 3 and < modulus.")
|
|
||||||
|
|
||||||
if public_exponent & 1 == 0:
|
|
||||||
raise ValueError("public_exponent must be odd.")
|
|
||||||
|
|
||||||
if dmp1 & 1 == 0:
|
|
||||||
raise ValueError("dmp1 must be odd.")
|
|
||||||
|
|
||||||
if dmq1 & 1 == 0:
|
|
||||||
raise ValueError("dmq1 must be odd.")
|
|
||||||
|
|
||||||
if p * q != modulus:
|
|
||||||
raise ValueError("p*q must equal modulus.")
|
|
||||||
|
|
||||||
|
|
||||||
def _check_public_key_components(e: int, n: int) -> None:
|
|
||||||
if n < 3:
|
|
||||||
raise ValueError("n must be >= 3.")
|
|
||||||
|
|
||||||
if e < 3 or e >= n:
|
|
||||||
raise ValueError("e must be >= 3 and < n.")
|
|
||||||
|
|
||||||
if e & 1 == 0:
|
|
||||||
raise ValueError("e must be odd.")
|
|
||||||
|
|
||||||
|
|
||||||
def _modinv(e: int, m: int) -> int:
|
|
||||||
"""
|
|
||||||
Modular Multiplicative Inverse. Returns x such that: (x*e) mod m == 1
|
|
||||||
"""
|
|
||||||
x1, x2 = 1, 0
|
|
||||||
a, b = e, m
|
|
||||||
while b > 0:
|
|
||||||
q, r = divmod(a, b)
|
|
||||||
xn = x1 - q * x2
|
|
||||||
a, b, x1, x2 = b, r, x2, xn
|
|
||||||
return x1 % m
|
|
||||||
|
|
||||||
|
|
||||||
def rsa_crt_iqmp(p: int, q: int) -> int:
|
|
||||||
"""
|
|
||||||
Compute the CRT (q ** -1) % p value from RSA primes p and q.
|
|
||||||
"""
|
|
||||||
return _modinv(q, p)
|
|
||||||
|
|
||||||
|
|
||||||
def rsa_crt_dmp1(private_exponent: int, p: int) -> int:
|
|
||||||
"""
|
|
||||||
Compute the CRT private_exponent % (p - 1) value from the RSA
|
|
||||||
private_exponent (d) and p.
|
|
||||||
"""
|
|
||||||
return private_exponent % (p - 1)
|
|
||||||
|
|
||||||
|
|
||||||
def rsa_crt_dmq1(private_exponent: int, q: int) -> int:
|
|
||||||
"""
|
|
||||||
Compute the CRT private_exponent % (q - 1) value from the RSA
|
|
||||||
private_exponent (d) and q.
|
|
||||||
"""
|
|
||||||
return private_exponent % (q - 1)
|
|
||||||
|
|
||||||
|
|
||||||
# Controls the number of iterations rsa_recover_prime_factors will perform
|
|
||||||
# to obtain the prime factors. Each iteration increments by 2 so the actual
|
|
||||||
# maximum attempts is half this number.
|
|
||||||
_MAX_RECOVERY_ATTEMPTS = 1000
|
|
||||||
|
|
||||||
|
|
||||||
def rsa_recover_prime_factors(
|
|
||||||
n: int, e: int, d: int
|
|
||||||
) -> typing.Tuple[int, int]:
|
|
||||||
"""
|
|
||||||
Compute factors p and q from the private exponent d. We assume that n has
|
|
||||||
no more than two factors. This function is adapted from code in PyCrypto.
|
|
||||||
"""
|
|
||||||
# See 8.2.2(i) in Handbook of Applied Cryptography.
|
|
||||||
ktot = d * e - 1
|
|
||||||
# The quantity d*e-1 is a multiple of phi(n), even,
|
|
||||||
# and can be represented as t*2^s.
|
|
||||||
t = ktot
|
|
||||||
while t % 2 == 0:
|
|
||||||
t = t // 2
|
|
||||||
# Cycle through all multiplicative inverses in Zn.
|
|
||||||
# The algorithm is non-deterministic, but there is a 50% chance
|
|
||||||
# any candidate a leads to successful factoring.
|
|
||||||
# See "Digitalized Signatures and Public Key Functions as Intractable
|
|
||||||
# as Factorization", M. Rabin, 1979
|
|
||||||
spotted = False
|
|
||||||
a = 2
|
|
||||||
while not spotted and a < _MAX_RECOVERY_ATTEMPTS:
|
|
||||||
k = t
|
|
||||||
# Cycle through all values a^{t*2^i}=a^k
|
|
||||||
while k < ktot:
|
|
||||||
cand = pow(a, k, n)
|
|
||||||
# Check if a^k is a non-trivial root of unity (mod n)
|
|
||||||
if cand != 1 and cand != (n - 1) and pow(cand, 2, n) == 1:
|
|
||||||
# We have found a number such that (cand-1)(cand+1)=0 (mod n).
|
|
||||||
# Either of the terms divides n.
|
|
||||||
p = gcd(cand + 1, n)
|
|
||||||
spotted = True
|
|
||||||
break
|
|
||||||
k *= 2
|
|
||||||
# This value was not any good... let's try another!
|
|
||||||
a += 2
|
|
||||||
if not spotted:
|
|
||||||
raise ValueError("Unable to compute factors p and q from exponent d.")
|
|
||||||
# Found !
|
|
||||||
q, r = divmod(n, p)
|
|
||||||
assert r == 0
|
|
||||||
p, q = sorted((p, q), reverse=True)
|
|
||||||
return (p, q)
|
|
||||||
|
|
||||||
|
|
||||||
class RSAPrivateNumbers:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
p: int,
|
|
||||||
q: int,
|
|
||||||
d: int,
|
|
||||||
dmp1: int,
|
|
||||||
dmq1: int,
|
|
||||||
iqmp: int,
|
|
||||||
public_numbers: "RSAPublicNumbers",
|
|
||||||
):
|
|
||||||
if (
|
|
||||||
not isinstance(p, int)
|
|
||||||
or not isinstance(q, int)
|
|
||||||
or not isinstance(d, int)
|
|
||||||
or not isinstance(dmp1, int)
|
|
||||||
or not isinstance(dmq1, int)
|
|
||||||
or not isinstance(iqmp, int)
|
|
||||||
):
|
|
||||||
raise TypeError(
|
|
||||||
"RSAPrivateNumbers p, q, d, dmp1, dmq1, iqmp arguments must"
|
|
||||||
" all be an integers."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(public_numbers, RSAPublicNumbers):
|
|
||||||
raise TypeError(
|
|
||||||
"RSAPrivateNumbers public_numbers must be an RSAPublicNumbers"
|
|
||||||
" instance."
|
|
||||||
)
|
|
||||||
|
|
||||||
self._p = p
|
|
||||||
self._q = q
|
|
||||||
self._d = d
|
|
||||||
self._dmp1 = dmp1
|
|
||||||
self._dmq1 = dmq1
|
|
||||||
self._iqmp = iqmp
|
|
||||||
self._public_numbers = public_numbers
|
|
||||||
|
|
||||||
@property
|
|
||||||
def p(self) -> int:
|
|
||||||
return self._p
|
|
||||||
|
|
||||||
@property
|
|
||||||
def q(self) -> int:
|
|
||||||
return self._q
|
|
||||||
|
|
||||||
@property
|
|
||||||
def d(self) -> int:
|
|
||||||
return self._d
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dmp1(self) -> int:
|
|
||||||
return self._dmp1
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dmq1(self) -> int:
|
|
||||||
return self._dmq1
|
|
||||||
|
|
||||||
@property
|
|
||||||
def iqmp(self) -> int:
|
|
||||||
return self._iqmp
|
|
||||||
|
|
||||||
@property
|
|
||||||
def public_numbers(self) -> "RSAPublicNumbers":
|
|
||||||
return self._public_numbers
|
|
||||||
|
|
||||||
def private_key(self, backend: typing.Any = None) -> RSAPrivateKey:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import (
|
|
||||||
backend as ossl,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ossl.load_rsa_private_numbers(self)
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, RSAPrivateNumbers):
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
return (
|
|
||||||
self.p == other.p
|
|
||||||
and self.q == other.q
|
|
||||||
and self.d == other.d
|
|
||||||
and self.dmp1 == other.dmp1
|
|
||||||
and self.dmq1 == other.dmq1
|
|
||||||
and self.iqmp == other.iqmp
|
|
||||||
and self.public_numbers == other.public_numbers
|
|
||||||
)
|
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
|
||||||
return hash(
|
|
||||||
(
|
|
||||||
self.p,
|
|
||||||
self.q,
|
|
||||||
self.d,
|
|
||||||
self.dmp1,
|
|
||||||
self.dmq1,
|
|
||||||
self.iqmp,
|
|
||||||
self.public_numbers,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RSAPublicNumbers:
|
|
||||||
def __init__(self, e: int, n: int):
|
|
||||||
if not isinstance(e, int) or not isinstance(n, int):
|
|
||||||
raise TypeError("RSAPublicNumbers arguments must be integers.")
|
|
||||||
|
|
||||||
self._e = e
|
|
||||||
self._n = n
|
|
||||||
|
|
||||||
@property
|
|
||||||
def e(self) -> int:
|
|
||||||
return self._e
|
|
||||||
|
|
||||||
@property
|
|
||||||
def n(self) -> int:
|
|
||||||
return self._n
|
|
||||||
|
|
||||||
def public_key(self, backend: typing.Any = None) -> RSAPublicKey:
|
|
||||||
from cryptography.hazmat.backends.openssl.backend import (
|
|
||||||
backend as ossl,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ossl.load_rsa_public_numbers(self)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return "<RSAPublicNumbers(e={0.e}, n={0.n})>".format(self)
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, RSAPublicNumbers):
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
return self.e == other.e and self.n == other.n
|
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
|
||||||
return hash((self.e, self.n))
|
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
# This file is dual licensed under the terms of the Apache License, Version
|
|
||||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
|
||||||
# for complete details.
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import (
|
|
||||||
dh,
|
|
||||||
dsa,
|
|
||||||
ec,
|
|
||||||
ed25519,
|
|
||||||
ed448,
|
|
||||||
rsa,
|
|
||||||
x25519,
|
|
||||||
x448,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Every asymmetric key type
|
|
||||||
PUBLIC_KEY_TYPES = typing.Union[
|
|
||||||
dh.DHPublicKey,
|
|
||||||
dsa.DSAPublicKey,
|
|
||||||
rsa.RSAPublicKey,
|
|
||||||
ec.EllipticCurvePublicKey,
|
|
||||||
ed25519.Ed25519PublicKey,
|
|
||||||
ed448.Ed448PublicKey,
|
|
||||||
x25519.X25519PublicKey,
|
|
||||||
x448.X448PublicKey,
|
|
||||||
]
|
|
||||||
# Every asymmetric key type
|
|
||||||
PRIVATE_KEY_TYPES = typing.Union[
|
|
||||||
dh.DHPrivateKey,
|
|
||||||
ed25519.Ed25519PrivateKey,
|
|
||||||
ed448.Ed448PrivateKey,
|
|
||||||
rsa.RSAPrivateKey,
|
|
||||||
dsa.DSAPrivateKey,
|
|
||||||
ec.EllipticCurvePrivateKey,
|
|
||||||
x25519.X25519PrivateKey,
|
|
||||||
x448.X448PrivateKey,
|
|
||||||
]
|
|
||||||
# Just the key types we allow to be used for x509 signing. This mirrors
|
|
||||||
# the certificate public key types
|
|
||||||
CERTIFICATE_PRIVATE_KEY_TYPES = typing.Union[
|
|
||||||
ed25519.Ed25519PrivateKey,
|
|
||||||
ed448.Ed448PrivateKey,
|
|
||||||
rsa.RSAPrivateKey,
|
|
||||||
dsa.DSAPrivateKey,
|
|
||||||
ec.EllipticCurvePrivateKey,
|
|
||||||
]
|
|
||||||
# Just the key types we allow to be used for x509 signing. This mirrors
|
|
||||||
# the certificate private key types
|
|
||||||
CERTIFICATE_ISSUER_PUBLIC_KEY_TYPES = typing.Union[
|
|
||||||
dsa.DSAPublicKey,
|
|
||||||
rsa.RSAPublicKey,
|
|
||||||
ec.EllipticCurvePublicKey,
|
|
||||||
ed25519.Ed25519PublicKey,
|
|
||||||
ed448.Ed448PublicKey,
|
|
||||||
]
|
|
||||||
# This type removes DHPublicKey. x448/x25519 can be a public key
|
|
||||||
# but cannot be used in signing so they are allowed here.
|
|
||||||
CERTIFICATE_PUBLIC_KEY_TYPES = typing.Union[
|
|
||||||
dsa.DSAPublicKey,
|
|
||||||
rsa.RSAPublicKey,
|
|
||||||
ec.EllipticCurvePublicKey,
|
|
||||||
ed25519.Ed25519PublicKey,
|
|
||||||
ed448.Ed448PublicKey,
|
|
||||||
x25519.X25519PublicKey,
|
|
||||||
x448.X448PublicKey,
|
|
||||||
]
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user