Skip to content

Commit

Permalink
Add more attributes to typed variables
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Oct 23, 2024
1 parent 1136cb2 commit 3f0caa6
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 16 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

Please add your functional changes to the appropriate section in the PR.
Keep it human-readable, your future self will thank you!

## [Unreleased]

### Changed

- Add more attributes to typed variables
14 changes: 11 additions & 3 deletions dev/dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@

mars = source_factory(
"mars",
)

r = dict(
param=["u", "v", "t", "q"],
grid=[20, 20],
date="20200101/to/20200105",
levelist=[1000, 850, 500],
)

data = mars.forward(None)
data = mars.forward(r)

for f in data:
print(f)
Expand All @@ -35,11 +38,16 @@
################

pipeline = workflow_factory("pipeline", filters=[mars, uv_2_ddff, ddff_2_uv])
for f in pipeline(None):
for f in pipeline(r):
print(f)

################
pipeline = mars | uv_2_ddff | ddff_2_uv


pipeline = r | mars | uv_2_ddff | ddff_2_uv

for f in pipeline:
print(f)


ipipe = pipeline.to_infernece()
3 changes: 1 addition & 2 deletions src/anemoi/transform/grids/unstructured.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(self, latitudes, longitudes, uuidOfHGrid=None):
assert isinstance(latitudes, np.ndarray), type(latitudes)
assert isinstance(longitudes, np.ndarray), type(longitudes)

LOG.info(f"Latitudes: {len(latitudes)}, Longitudes: {len(longitudes)}")
assert len(latitudes) == len(longitudes)

self.uuidOfHGrid = uuidOfHGrid
Expand Down Expand Up @@ -95,7 +94,7 @@ def from_grib(cls, latitudes_url_or_path, longitudes_url_or_path, latitudes_para
return cls([UnstructuredGridField(Geography(latitudes, longitudes))])

@classmethod
def from_values(cls, latitudes, longitudes):
def from_values(cls, *, latitudes, longitudes):
if isinstance(latitudes, (list, tuple)):
latitudes = np.array(latitudes)

Expand Down
12 changes: 9 additions & 3 deletions src/anemoi/transform/sources/mars.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@ class Mars(Source):
"""A demo source"""

def __init__(self, **request):
self.request = request
pass

def forward(self, data):
assert data is None
return ekd.from_source("mars", **data)

return ekd.from_source("mars", **self.request)
def __ror__(self, data):

class Input:
def __init__(self, data):
self.data = data

return Input(data)


register_source("mars", Mars)
17 changes: 10 additions & 7 deletions src/anemoi/transform/variables/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import json

from . import Variable

Expand All @@ -16,11 +15,7 @@ class VariableFromMarsVocabulary(Variable):
def __init__(self, name, data: dict) -> None:
super().__init__(name)
self.data = data
print(json.dumps(data, indent=4))
if "mars" in self.data:
self.mars = self.data["mars"]
else:
self.mars = self.data
self.mars = self.data.get("mars", {})

@property
def is_pressure_level(self):
Expand All @@ -32,7 +27,15 @@ def level(self):

@property
def is_constant_in_time(self):
return self.data.get("is_constant_in_time", False)
return self.data.get("constant_in_time", False)

@property
def is_from_input(self):
return "mars" in self.data

@property
def is_computed_forcing(self):
return self.data.get("computed_forcing", False)


class VariableFromDict(VariableFromMarsVocabulary):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
tlon = "tlon"


def test_unstructured_from_url():
def do_not_test_unstructured_from_url():
ds = UnstructuredGridFieldList.from_grib(latitude_url, longitudes_url, tlat, tlon)

assert len(ds) == 1
Expand Down

0 comments on commit 3f0caa6

Please sign in to comment.