Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions mypyc/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ def build_using_shared_lib(
deps: list[str],
build_dir: str,
extra_compile_args: list[str],
extra_include_dirs: list[str],
) -> list[Extension]:
"""Produce the list of extension modules when a shared library is needed.

Expand All @@ -373,7 +374,7 @@ def build_using_shared_lib(
get_extension()(
shared_lib_name(group_name),
sources=cfiles,
include_dirs=[include_dir(), build_dir],
include_dirs=[include_dir(), build_dir] + extra_include_dirs,
depends=deps,
extra_compile_args=extra_compile_args,
)
Expand All @@ -399,7 +400,10 @@ def build_using_shared_lib(


def build_single_module(
sources: list[BuildSource], cfiles: list[str], extra_compile_args: list[str]
sources: list[BuildSource],
cfiles: list[str],
extra_compile_args: list[str],
extra_include_dirs: list[str],
) -> list[Extension]:
"""Produce the list of extension modules for a standalone extension.

Expand All @@ -409,7 +413,7 @@ def build_single_module(
get_extension()(
sources[0].module,
sources=cfiles,
include_dirs=[include_dir()],
include_dirs=[include_dir()] + extra_include_dirs,
extra_compile_args=extra_compile_args,
)
]
Expand Down Expand Up @@ -513,7 +517,9 @@ def mypyc_build(
*,
separate: bool | list[tuple[list[str], str | None]] = False,
only_compile_paths: Iterable[str] | None = None,
skip_cgen_input: tuple[list[list[tuple[str, str]]], list[str]] | None = None,
skip_cgen_input: (
tuple[list[list[tuple[str, str]]], list[tuple[str, list[str], bool]]] | None
) = None,
always_use_shared_lib: bool = False,
) -> tuple[emitmodule.Groups, list[tuple[list[str], list[str]]], list[SourceDep]]:
"""Do the front and middle end of mypyc building, producing and writing out C source."""
Expand Down Expand Up @@ -547,7 +553,10 @@ def mypyc_build(
write_file(os.path.join(compiler_options.target_dir, "ops.txt"), ops_text)
else:
group_cfiles = skip_cgen_input[0]
source_deps = [SourceDep(d) for d in skip_cgen_input[1]]
source_deps = [
SourceDep(path, include_dirs=dirs, internal=internal)
for (path, dirs, internal) in skip_cgen_input[1]
]

# Write out the generated C and collect the files for each group
# Should this be here??
Expand Down Expand Up @@ -664,7 +673,9 @@ def mypycify(
strip_asserts: bool = False,
multi_file: bool = False,
separate: bool | list[tuple[list[str], str | None]] = False,
skip_cgen_input: tuple[list[list[tuple[str, str]]], list[str]] | None = None,
skip_cgen_input: (
tuple[list[list[tuple[str, str]]], list[tuple[str, list[str], bool]]] | None
) = None,
target_dir: str | None = None,
include_runtime_files: bool | None = None,
strict_dunder_typing: bool = False,
Expand Down Expand Up @@ -781,12 +792,19 @@ def mypycify(
# runtime library in. Otherwise it just gets #included to save on
# compiler invocations.
shared_cfilenames = []
include_dirs = set()
if not compiler_options.include_runtime_files:
# Collect all files to copy: runtime files + conditional source files
files_to_copy = list(RUNTIME_C_FILES)
for source_dep in source_deps:
files_to_copy.append(source_dep.path)
files_to_copy.append(source_dep.get_header())
include_dirs.update(source_dep.include_dirs)

if compiler_options.depends_on_librt_internal:
files_to_copy.append("internal/librt_internal_api.h")
files_to_copy.append("internal/librt_internal_api.c")
include_dirs.add("internal")

# Copy all files
for name in files_to_copy:
Expand All @@ -797,6 +815,7 @@ def mypycify(
shared_cfilenames.append(rt_file)

extensions = []
extra_include_dirs = [os.path.join(include_dir(), dir) for dir in include_dirs]
for (group_sources, lib_name), (cfilenames, deps) in zip(groups, group_cfilenames):
if lib_name:
extensions.extend(
Expand All @@ -807,11 +826,14 @@ def mypycify(
deps,
build_dir,
cflags,
extra_include_dirs,
)
)
else:
extensions.extend(
build_single_module(group_sources, cfilenames + shared_cfilenames, cflags)
build_single_module(
group_sources, cfilenames + shared_cfilenames, cflags, extra_include_dirs
)
)

if install_librt:
Expand Down
47 changes: 30 additions & 17 deletions mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,13 +436,23 @@ def load_scc_from_cache(
return modules


def collect_source_dependencies(modules: dict[str, ModuleIR]) -> set[SourceDep]:
"""Collect all SourceDep dependencies from all modules."""
def collect_source_dependencies(
modules: dict[str, ModuleIR], *, internal: bool = True
) -> set[SourceDep]:
"""Collect all SourceDep dependencies from all modules.

If internal is set to False, returns only the dependencies that can be exported to C extensions
dependent on the one currently being compiled.
"""
source_deps: set[SourceDep] = set()
for module in modules.values():
for dep in module.dependencies:
if isinstance(dep, SourceDep):
source_deps.add(dep)
if internal == dep.internal:
source_deps.add(dep)
else:
capsule_dep = dep.internal_dep() if internal else dep.external_dep()
source_deps.add(capsule_dep)
return source_deps


Expand Down Expand Up @@ -585,6 +595,8 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:
source_deps = collect_source_dependencies(self.modules)
for source_dep in sorted(source_deps, key=lambda d: d.path):
base_emitter.emit_line(f'#include "{source_dep.path}"')
if self.compiler_options.depends_on_librt_internal:
base_emitter.emit_line('#include "internal/librt_internal_api.c"')
base_emitter.emit_line(f'#include "__native{self.short_group_suffix}.h"')
base_emitter.emit_line(f'#include "__native_internal{self.short_group_suffix}.h"')
emitter = base_emitter
Expand Down Expand Up @@ -634,26 +646,27 @@ def generate_c_for_modules(self) -> list[tuple[str, str]]:
ext_declarations.emit_line(f"#define MYPYC_NATIVE{self.group_suffix}_H")
ext_declarations.emit_line("#include <Python.h>")
ext_declarations.emit_line("#include <CPy.h>")
if self.compiler_options.depends_on_librt_internal:
ext_declarations.emit_line("#include <internal/librt_internal.h>")
if any(LIBRT_BASE64 in mod.dependencies for mod in self.modules.values()):
ext_declarations.emit_line("#include <base64/librt_base64.h>")
if any(LIBRT_STRINGS in mod.dependencies for mod in self.modules.values()):
ext_declarations.emit_line("#include <strings/librt_strings.h>")
if any(LIBRT_TIME in mod.dependencies for mod in self.modules.values()):
ext_declarations.emit_line("#include <time/librt_time.h>")
if any(LIBRT_VECS in mod.dependencies for mod in self.modules.values()):
ext_declarations.emit_line("#include <vecs/librt_vecs.h>")
# Include headers for conditional source files
source_deps = collect_source_dependencies(self.modules)
for source_dep in sorted(source_deps, key=lambda d: d.path):
ext_declarations.emit_line(f'#include "{source_dep.get_header()}"')

def emit_dep_headers(decls: Emitter, internal: bool) -> None:
suffix = "_api" if internal else ""
if self.compiler_options.depends_on_librt_internal:
decls.emit_line(f'#include "internal/librt_internal{suffix}.h"')
# Include headers for conditional source files
source_deps = collect_source_dependencies(self.modules, internal=internal)
for source_dep in sorted(source_deps, key=lambda d: d.path):
decls.emit_line(f'#include "{source_dep.get_header()}"')

emit_dep_headers(ext_declarations, False)

declarations = Emitter(self.context)
declarations.emit_line(f"#ifndef MYPYC_LIBRT_INTERNAL{self.group_suffix}_H")
declarations.emit_line(f"#define MYPYC_LIBRT_INTERNAL{self.group_suffix}_H")
declarations.emit_line("#include <Python.h>")
declarations.emit_line("#include <CPy.h>")

if not self.compiler_options.include_runtime_files:
emit_dep_headers(declarations, True)

declarations.emit_line(f'#include "__native{self.short_group_suffix}.h"')
declarations.emit_line()
declarations.emit_line("int CPyGlobalsInit(void);")
Expand Down
35 changes: 34 additions & 1 deletion mypyc/ir/deps.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Final


Expand All @@ -17,17 +19,48 @@ def __eq__(self, other: object) -> bool:
def __hash__(self) -> int:
return hash(("Capsule", self.name))

def internal_dep(self) -> SourceDep:
"""Internal source dependency of the capsule that should only be included in the C extensions
that depend on the capsule, eg. by importing a type or function from the capsule.
"""
module = self.name.split(".")[-1]
return SourceDep(f"{module}/librt_{module}_api.c", include_dirs=[module])

# TODO: This SourceDep is really only used for its associated header so it would make more sense
# to add a separate type. Alternatively, see if this can be removed altogether if we move the
# definitions that depend on this header from the external header of the C extension.
def external_dep(self) -> SourceDep:
"""External source dependency of the capsule that may be included in external headers of C
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was pretty confusing. Is this only used for the header? The SourceDep however has a reference to a .c file, which is quite different. I think it could be cleaner to separate e.g. librt_strings.h and librt_strings.c dependencies, since their use cases are quite different. This can happen in a follow-up PR, if you add a TODO here.

extensions that depend on the capsule.

The external headers of the C extensions are included by other C extensions that don't
necessarily import the capsule. However, they may need type definitions from the capsule
for types that are used in the exports table of the included C extensions.

Only the external header should be included in this case because if the other C extension
doesn't import the capsule, it also doesn't include the definition for its API table and
including the internal header would result in undefined symbols.
"""
module = self.name.split(".")[-1]
return SourceDep(f"{module}/librt_{module}.c", include_dirs=[module], internal=False)


class SourceDep:
"""Defines a C source file that a primitive may require.

Each source file must also have a corresponding .h file (replace .c with .h)
that gets implicitly #included if the source is used.
include_dirs are passed to the C compiler when the file is compiled as a
shared library separate from the C extension.
"""

def __init__(self, path: str) -> None:
def __init__(
self, path: str, *, include_dirs: list[str] | None = None, internal: bool = True
) -> None:
# Relative path from mypyc/lib-rt, e.g. 'bytes_extra_ops.c'
self.path: Final = path
self.include_dirs: Final = include_dirs or []
self.internal: Final = internal

def __repr__(self) -> str:
return f"SourceDep(path={self.path!r})"
Expand Down
16 changes: 14 additions & 2 deletions mypyc/ir/module_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,13 @@ def serialize(self) -> JsonDict:
if isinstance(dep, Capsule):
serialized_deps.append({"type": "Capsule", "name": dep.name})
elif isinstance(dep, SourceDep):
serialized_deps.append({"type": "SourceDep", "path": dep.path})
source_dep: JsonDict = {
"type": "SourceDep",
"path": dep.path,
"include_dirs": dep.include_dirs,
"internal": dep.internal,
}
serialized_deps.append(source_dep)

return {
"fullname": self.fullname,
Expand Down Expand Up @@ -69,7 +75,13 @@ def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ModuleIR:
if dep_dict["type"] == "Capsule":
deps.add(Capsule(dep_dict["name"]))
elif dep_dict["type"] == "SourceDep":
deps.add(SourceDep(dep_dict["path"]))
deps.add(
SourceDep(
dep_dict["path"],
include_dirs=dep_dict["include_dirs"],
internal=dep_dict["internal"],
)
)
module.dependencies = deps

return module
Expand Down
40 changes: 0 additions & 40 deletions mypyc/lib-rt/base64/librt_base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,44 +7,4 @@
#define LIBRT_BASE64_API_VERSION 2
#define LIBRT_BASE64_API_LEN 4

static void *LibRTBase64_API[LIBRT_BASE64_API_LEN];

#define LibRTBase64_ABIVersion (*(int (*)(void)) LibRTBase64_API[0])
#define LibRTBase64_APIVersion (*(int (*)(void)) LibRTBase64_API[1])
#define LibRTBase64_b64encode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[2])
#define LibRTBase64_b64decode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[3])

static int
import_librt_base64(void)
{
PyObject *mod = PyImport_ImportModule("librt.base64");
if (mod == NULL)
return -1;
Py_DECREF(mod); // we import just for the side effect of making the below work.
void *capsule = PyCapsule_Import("librt.base64._C_API", 0);
if (capsule == NULL)
return -1;
memcpy(LibRTBase64_API, capsule, sizeof(LibRTBase64_API));
if (LibRTBase64_ABIVersion() != LIBRT_BASE64_ABI_VERSION) {
char err[128];
snprintf(err, sizeof(err), "ABI version conflict for librt.base64, expected %d, found %d",
LIBRT_BASE64_ABI_VERSION,
LibRTBase64_ABIVersion()
);
PyErr_SetString(PyExc_ValueError, err);
return -1;
}
if (LibRTBase64_APIVersion() < LIBRT_BASE64_API_VERSION) {
char err[128];
snprintf(err, sizeof(err),
"API version conflict for librt.base64, expected %d or newer, found %d (hint: upgrade librt)",
LIBRT_BASE64_API_VERSION,
LibRTBase64_APIVersion()
);
PyErr_SetString(PyExc_ValueError, err);
return -1;
}
return 0;
}

#endif // LIBRT_BASE64_H
36 changes: 36 additions & 0 deletions mypyc/lib-rt/base64/librt_base64_api.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include "librt_base64_api.h"

void *LibRTBase64_API[LIBRT_BASE64_API_LEN] = {0};

int
import_librt_base64(void)
{
PyObject *mod = PyImport_ImportModule("librt.base64");
if (mod == NULL)
return -1;
Py_DECREF(mod); // we import just for the side effect of making the below work.
void *capsule = PyCapsule_Import("librt.base64._C_API", 0);
if (capsule == NULL)
return -1;
memcpy(LibRTBase64_API, capsule, sizeof(LibRTBase64_API));
if (LibRTBase64_ABIVersion() != LIBRT_BASE64_ABI_VERSION) {
char err[128];
snprintf(err, sizeof(err), "ABI version conflict for librt.base64, expected %d, found %d",
LIBRT_BASE64_ABI_VERSION,
LibRTBase64_ABIVersion()
);
PyErr_SetString(PyExc_ValueError, err);
return -1;
}
if (LibRTBase64_APIVersion() < LIBRT_BASE64_API_VERSION) {
char err[128];
snprintf(err, sizeof(err),
"API version conflict for librt.base64, expected %d or newer, found %d (hint: upgrade librt)",
LIBRT_BASE64_API_VERSION,
LibRTBase64_APIVersion()
);
PyErr_SetString(PyExc_ValueError, err);
return -1;
}
return 0;
}
15 changes: 15 additions & 0 deletions mypyc/lib-rt/base64/librt_base64_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef LIBRT_BASE64_API_H
#define LIBRT_BASE64_API_H

#include "librt_base64.h"

extern void *LibRTBase64_API[LIBRT_BASE64_API_LEN];

#define LibRTBase64_ABIVersion (*(int (*)(void)) LibRTBase64_API[0])
#define LibRTBase64_APIVersion (*(int (*)(void)) LibRTBase64_API[1])
#define LibRTBase64_b64encode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[2])
#define LibRTBase64_b64decode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[3])

int import_librt_base64(void);

#endif // LIBRT_BASE64_API_H
2 changes: 1 addition & 1 deletion mypyc/lib-rt/byteswriter_extra_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <Python.h>

#include "mypyc_util.h"
#include "strings/librt_strings.h"
#include "strings/librt_strings_api.h"
#include "strings/librt_strings_common.h"

// BytesWriter: Length and capacity
Expand Down
Loading
Loading