Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 53 additions & 10 deletions src/io4dolfinx/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations

import logging
import typing
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -51,6 +52,8 @@
"write_attributes",
]

logger = logging.getLogger(__name__)


def write_attributes(
filename: Path | str,
Expand All @@ -70,6 +73,8 @@ def write_attributes(
backend_args: Arguments for backend, for instance file type.
backend: What backend to use for writing.
"""
logger.debug(f"Writing attributes to {filename} for attribute {name}")
logger.debug(f"Using {backend} backend with arguments {backend_args} to write attributes")
backend_cls = get_backend(backend)
backend_args = backend_cls.get_default_backend_args(backend_args)
backend_cls.write_attributes(filename, comm, name, attributes, backend_args)
Expand All @@ -93,6 +98,8 @@ def read_attributes(
Returns:
The attributes
"""
logger.debug(f"Reading attributes from {filename} for attribute {name}")
logger.debug(f"Using {backend} backend with arguments {backend_args} to read attributes")
backend_cls = get_backend(backend)
backend_args = backend_cls.get_default_backend_args(backend_args)
return backend_cls.read_attributes(filename, comm, name, backend_args)
Expand All @@ -117,6 +124,8 @@ def read_timestamps(
Returns:
The time-stamps
"""
logger.debug(f"Reading time-stamps from {filename} for function {function_name}")
logger.debug(f"Using {backend} backend with arguments {backend_args} to read time-stamps")
check_file_exists(filename)
backend_cls = get_backend(backend)
backend_args = backend_cls.get_default_backend_args(backend_args)
Expand Down Expand Up @@ -146,6 +155,8 @@ def write_meshtags(
backend_args: Option to IO backend.
backend: IO backend
"""
logger.debug(f"Writing meshtags to {filename} for meshtag {meshtag_name or meshtags.name}")
logger.debug(f"Using {backend} backend with arguments {backend_args} to write meshtags")

# Extract data from meshtags (convert to global geometry node indices for each entity)
tag_entities = meshtags.indices
Expand Down Expand Up @@ -212,6 +223,8 @@ def read_meshtags(
Returns:
The meshtags
"""
logger.debug(f"Reading meshtags from {filename} for meshtag {meshtag_name}")
logger.debug(f"Using {backend} backend with arguments {backend_args} to read meshtags")
check_file_exists(filename)
backend_cls = get_backend(backend)
backend_args = backend_cls.get_default_backend_args(backend_args)
Expand Down Expand Up @@ -249,14 +262,23 @@ def read_function(
time: Time-stamp associated with checkpoint
name: If not provided, `u.name` is used to search through the input file for the function
"""
logger.debug(
f"Reading function checkpoint from {filename} for function {name or u.name} at time {time}"
)
logger.debug(
f"Using {backend} backend with arguments {backend_args} to read function checkpoint"
)
check_file_exists(filename)

mesh = u.function_space.mesh
comm = mesh.comm
if name is None:
name = u.name

# ----------------------Step 1---------------------------------
check_file_exists(filename)
backend_cls = get_backend(backend)
backend_args = backend_cls.get_default_backend_args(backend_args)

# Compute index of input cells and get cell permutation
num_owned_cells = mesh.topology.index_map(mesh.topology.dim).size_local
input_cells = mesh.topology.original_cell_index[:num_owned_cells]
Expand All @@ -265,7 +287,6 @@ def read_function(

# Compute mesh->input communicator
# 1.1 Compute mesh->input communicator
backend_cls = get_backend(backend)
owners: npt.NDArray[np.int32]
if backend_cls.read_mode == ReadMode.serial:
owners = np.zeros(input_cells, dtype=np.int32)
Expand All @@ -278,12 +299,6 @@ def read_function(
# Send and receive global cell index and cell perm
inc_cells, inc_perms = send_and_recv_cell_perm(input_cells, cell_perm, owners, mesh.comm)

# -------------------Step 3-----------------------------------
# Read dofmap from file and compute dof owners
check_file_exists(filename)
backend_cls = get_backend(backend)
backend_args = backend_cls.get_default_backend_args(backend_args)

input_dofmap = backend_cls.read_dofmap(filename, comm, name, backend_args)

# Compute owner of dofs in dofmap
Expand Down Expand Up @@ -403,6 +418,11 @@ def read_mesh(
Returns:
The distributed mesh
"""
logger.debug(f"Reading mesh from {filename}")
logger.debug(
f"Using {backend} backend with arguments {backend_args}, "
f"time {time} and read_from_partition {read_from_partition}"
)
# Read in data in a distributed fashin
check_file_exists(filename)
backend_cls = get_backend(backend)
Expand Down Expand Up @@ -480,8 +500,13 @@ def write_mesh(

store_partition_info: Store mesh partitioning (including ghosting) to file
"""
logger.debug(f"Writing mesh to {filename}")
logger.debug(f"Preparing mesh data for storage storing partition info: {store_partition_info}")
mesh_data = prepare_meshdata_for_storage(mesh=mesh, store_partition_info=store_partition_info)

logger.debug(
f"Write mesh using {backend} backend, with arguments {backend_args}, "
f"mode {mode} and time {time}"
)
_internal_mesh_writer(
filename,
mesh.comm,
Expand Down Expand Up @@ -514,6 +539,13 @@ def write_function(
backend_args: Arguments to the IO backend.
backend: The backend to use
"""
logger.debug(
f"Writing function checkpoint to {filename} for function {name or u.name} at time {time}"
)
logger.debug(
f"Extracting data from function and dofmap for storage using {backend} "
f"backend with arguments {backend_args}"
)
dofmap = u.function_space.dofmap
values = u.x.array
mesh = u.function_space.mesh
Expand Down Expand Up @@ -587,6 +619,9 @@ def read_function_names(
Returns:
A list of function names.
"""
logger.debug(f"Reading function names from {filename}")
logger.debug(f"Using {backend} backend with arguments {backend_args} to read function names")
check_file_exists(filename)
backend_cls = get_backend(backend)
return backend_cls.read_function_names(filename, comm, backend_args=backend_args)

Expand All @@ -610,6 +645,7 @@ def write_point_data(
backend_args: The backend arguments
backend: Which backend to use.
"""
logger.debug(f"Writing point data to {filename} for function {u.name} at time {time}")
V = create_geometry_function_space(u.function_space.mesh, int(np.prod(u.ufl_shape)))
v_out = dolfinx.fem.Function(V, name=u.name, dtype=u.x.array.dtype)
v_out.interpolate(u)
Expand All @@ -621,6 +657,9 @@ def write_point_data(
ad = ArrayData(
name=v_out.name, values=data, global_shape=data_shape, local_range=local_range, type="Point"
)
logger.debug(
f"Using {backend} backend with arguments {backend_args} and mode {mode} to write point data"
)
backend_cls = get_backend(backend)
return backend_cls.write_data(
filename, comm=comm, mode=mode, time=time, array_data=ad, backend_args=backend_args
Expand All @@ -645,6 +684,7 @@ def write_cell_data(
mode: Append or write
backend_args: The backend arguments
"""
logger.debug(f"Writing cell data to {filename} for function {u.name} at time {time}")
V = dolfinx.fem.functionspace(u.function_space.mesh, ("DG", 0, u.ufl_shape))
v_out = dolfinx.fem.Function(V, name=u.name, dtype=u.x.array.dtype)
v_out.interpolate(u)
Expand All @@ -654,11 +694,14 @@ def write_cell_data(
num_dofs_local = V.dofmap.index_map.size_local
data = v_out.x.array.reshape(-1, V.dofmap.index_map_bs)[:num_dofs_local]

backend_cls = get_backend(backend)
ad = ArrayData(
name=v_out.name, values=data, global_shape=data_shape, local_range=local_range, type="Cell"
)
logger.debug(
f"Using {backend} backend with arguments {backend_args} and mode {mode} to write cell data"
)
backend_cls = get_backend(backend)

return backend_cls.write_data(
filename, comm=comm, mode=mode, time=time, array_data=ad, backend_args=backend_args
)
8 changes: 8 additions & 0 deletions src/io4dolfinx/original_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations

import logging
import typing
from pathlib import Path

Expand All @@ -27,6 +28,7 @@
)

__all__ = ["write_function_on_input_mesh", "write_mesh_input_order"]
logger = logging.getLogger(__name__)


def create_original_mesh_data(mesh: dolfinx.mesh.Mesh) -> MeshData:
Expand Down Expand Up @@ -361,6 +363,10 @@ def write_function_on_input_mesh(
backend_args: Arguments to backend
backend: Choice of backend module
"""
logger.debug(
f"Writing function on input mesh to {filename} at time {time} with name {name or u.name}"
)
logger.debug(f"Using backend {backend} with arguments {backend_args} and mode {mode}")
mesh = u.function_space.mesh
function_data = create_function_data_on_original_mesh(u, name)
fname = Path(filename)
Expand Down Expand Up @@ -400,6 +406,8 @@ def write_mesh_input_order(
backend_args: Arguments to backend
backend: Choice of backend module
"""
logger.debug(f"Writing mesh in input order to {filename} at time {time}")
logger.debug(f"Using backend {backend} with arguments {backend_args} and mode {mode}")
mesh_data = create_original_mesh_data(mesh)
fname = Path(filename)

Expand Down
9 changes: 8 additions & 1 deletion src/io4dolfinx/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations

import logging
import pathlib
import typing
from pathlib import Path
Expand All @@ -30,6 +31,7 @@
)

__all__ = ["read_mesh_from_legacy_h5", "read_function_from_legacy_h5", "read_point_data"]
logger = logging.getLogger(__name__)


def map_dofmap(dofmap: dolfinx.graph.AdjacencyList, bs: int) -> npt.NDArray[np.int64]:
Expand Down Expand Up @@ -164,6 +166,8 @@ def read_mesh_from_legacy_h5(
max_facet_to_cell_links: Maximum number of cells a facet
can be connected to.
"""
logger.debug(f"Reading mesh from {filename} at group {group}")
logger.debug(f"Using backend {backend} with max_facet_to_cell_links {max_facet_to_cell_links}")
# Make sure we use the HDF5File and check that the file is present
check_file_exists(filename)

Expand Down Expand Up @@ -240,7 +244,8 @@ def read_function_from_legacy_h5(
the function is saved as a regular function (i.e with `HDF5File.write`)
backend: The IO backend
"""

logger.debug(f"Reading function from {filename} at group {group}")
logger.debug(f"Using backend {backend} with group {group} and step {step}")
# Make sure we use the HDF5File and check that the file is present
filename = pathlib.Path(filename)
if filename.suffix == ".xdmf":
Expand Down Expand Up @@ -402,6 +407,8 @@ def read_point_data(
coordinate element (up to shape).
"""

logger.debug(f"Reading point data from {filename} with name {name} at time {time}")
logger.debug(f"Using backend {backend} with arguments {backend_args}")
backend_cls = get_backend(backend)
dataset, local_range_start = backend_cls.read_point_data(
filename=filename, name=name, comm=mesh.comm, time=time, backend_args=backend_args
Expand Down
5 changes: 4 additions & 1 deletion src/io4dolfinx/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#
# SPDX-License-Identifier: MIT

import logging
from pathlib import Path
from typing import Any

Expand All @@ -14,6 +15,7 @@
__all__ = [
"snapshot_checkpoint",
]
logger = logging.getLogger(__name__)


def snapshot_checkpoint(
Expand All @@ -31,7 +33,8 @@ def snapshot_checkpoint(
:param file: The file to write to or read from
:param mode: Either read or write
"""

logger.debug(f"Performing snapshot checkpoint with mode {mode} on file {file}")
logger.debug(f"Using backend {backend} with arguments {backend_args}")
backend_cls = get_backend(backend)
default_args = backend_cls.get_default_backend_args(backend_args)
if mode not in [FileMode.write, FileMode.read]:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_checkpointing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
from pathlib import Path
from unittest.mock import Mock

from mpi4py import MPI

Expand Down Expand Up @@ -237,7 +238,7 @@ def g(x):
(io4dolfinx.read_attributes, ("nonexisting_file", MPI.COMM_WORLD, "")),
(io4dolfinx.read_timestamps, ("nonexisting_file", MPI.COMM_WORLD, "")),
(io4dolfinx.read_meshtags, ("nonexisting_file", MPI.COMM_WORLD, None, "")),
(io4dolfinx.read_function, ("nonexisting_file", None)),
(io4dolfinx.read_function, ("nonexisting_file", Mock())),
(io4dolfinx.read_mesh, ("nonexisting_file", MPI.COMM_WORLD)),
],
)
Expand Down