Skip to content

Commit

Permalink
CLN: Add types to _roff_parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
mferrera committed Nov 23, 2023
1 parent aa12f39 commit 80049cc
Showing 1 changed file with 44 additions and 31 deletions.
75 changes: 44 additions & 31 deletions src/xtgeo/grid3d/_roff_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
import warnings
from collections import defaultdict
from dataclasses import dataclass
from typing import Any
from typing import TYPE_CHECKING, Any

import numpy as np
import roffio

from xtgeo.common.constants import UNDEF_INT_LIMIT, UNDEF_LIMIT

if TYPE_CHECKING:
from xtgeo.common.types import FileLike

from .grid_property import GridProperty


@dataclass
class RoffParameter:
Expand Down Expand Up @@ -91,9 +96,10 @@ def undefined_value(self) -> int | float:
return -999
if np.issubdtype(self.values.dtype, np.floating):
return -999.0
raise ValueError(f"Parameter values of unsupported type {type(self.values)}")

@property
def is_discrete(self):
def is_discrete(self) -> bool:
"""
Returns:
True if the RoffParameter is a discrete type
Expand All @@ -102,7 +108,7 @@ def is_discrete(self):
self.values.dtype, np.integer
)

def xtgeo_codes(self):
def xtgeo_codes(self) -> dict[int, str]:
"""
Returns:
The discrete codes of the parameter in the format of
Expand All @@ -113,29 +119,25 @@ def xtgeo_codes(self):
else:
return dict()

def xtgeo_values(self):
def xtgeo_values(self) -> np.ndarray:
"""
Args:
The value to use for undefined. Defaults to that defined by
roff.
Returns:
The values in the format of xtgeo grid property
"""
vals = self.values
if isinstance(vals, bytes):
vals = np.ndarray(len(vals), np.uint8, vals)
vals = vals.copy()
vals = np.flip(vals.reshape((self.nx, self.ny, self.nz)), -1)

if self.is_discrete:
vals = vals.astype(np.int32)
if isinstance(self.values, bytes):
vals: np.ndarray = np.ndarray(len(self.values), np.uint8, self.values)
else:
vals = vals.astype(np.float64)
vals = self.values.copy()

vals = np.flip(vals.reshape((self.nx, self.ny, self.nz)), -1)
vals = vals.astype(np.int32) if self.is_discrete else vals.astype(np.float64)
return np.ma.masked_values(vals, self.undefined_value)

@staticmethod
def from_xtgeo_grid_property(xtgeo_grid_property):
def from_xtgeo_grid_property(xtgeo_grid_property: GridProperty) -> RoffParameter:
"""
Args:
xtgeo_grid_property (xtgeo.GridProperty): Any xtgeo.GridProperty
Expand Down Expand Up @@ -166,27 +168,31 @@ def from_xtgeo_grid_property(xtgeo_grid_property):
values = values.astype(np.float32).filled(-999.0)

return RoffParameter(
*xtgeo_grid_property.dimensions,
nx=xtgeo_grid_property.ncol,
ny=xtgeo_grid_property.nrow,
nz=xtgeo_grid_property.nlay,
name=xtgeo_grid_property.name,
values=np.asarray(np.flip(values, -1).ravel()),
code_names=code_names,
code_values=code_values,
)

def to_file(self, filelike, roff_format=roffio.Format.BINARY):
def to_file(
self,
filelike: FileLike,
roff_format: roffio.Format = roffio.Format.BINARY,
) -> None:
"""
Writes the RoffParameter to a roff file
Args:
filelike (str or byte stream): The file to write to.
roff_format (roffio.Format): The format to write the file in.
"""
data = dict(
{
"filedata": {"filetype": "parameter"},
"dimensions": {"nX": self.nx, "nY": self.ny, "nZ": self.nz},
"parameter": {"name": self.name},
}
)
data: dict[str, dict[str, Any]] = {
"filedata": {"filetype": "parameter"},
"dimensions": {"nX": self.nx, "nY": self.ny, "nZ": self.nz},
"parameter": {"name": self.name},
}
if self.code_names is not None:
data["parameter"]["codeNames"] = list(self.code_names)
if self.code_values is not None:
Expand All @@ -197,7 +203,7 @@ def to_file(self, filelike, roff_format=roffio.Format.BINARY):
roffio.write(filelike, data, roff_format=roff_format)

@staticmethod
def from_file(filelike, name=None):
def from_file(filelike: FileLike, name: str | None = None) -> RoffParameter:
"""
Read a RoffParameter from a roff file
Args:
Expand All @@ -208,7 +214,7 @@ def from_file(filelike, name=None):
The RoffGrid in the roff file.
"""

def should_skip_parameter(tag, key):
def should_skip_parameter(tag: str, key: str) -> bool:
if tag == "parameter" and key[0] == "name":
if name is None or key[1] == name:
return False
Expand Down Expand Up @@ -281,11 +287,18 @@ def should_skip_parameter(tag, key):
f" have filetype parameter or grid, found {filetype}"
)

roff: dict[str, Any] = {}
for tag, tag_keys in translate_kws.items():
for key, translated in tag_keys.items():
if found[tag][key] is not None:
roff[translated] = found[tag][key]

return RoffParameter(
**{
translated: found[tag][key]
for tag, tag_keys in translate_kws.items()
for key, translated in tag_keys.items()
if found[tag][key] is not None
}
nx=roff["nx"],
ny=roff["ny"],
nz=roff["nz"],
name=roff["name"],
values=roff["values"],
code_names=roff.get("code_names", None),
code_values=roff.get("code_values", None),
)

0 comments on commit 80049cc

Please sign in to comment.