-
Notifications
You must be signed in to change notification settings - Fork 0
/
udf_segmentation.py
150 lines (122 loc) · 5.22 KB
/
udf_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import functools
import gc
import sys
from typing import Dict
import numpy as np
import xarray as xr
from openeo.udf import inspect
from xarray.ufuncs import isnan as ufuncs_isnan
from xarray import DataArray
# Add the onnx dependencies to the path
sys.path.insert(0, "onnx_deps")
import onnxruntime as ort
## File names (inside the model directory) of the networks used in the ensemble.
## Kept as a frozenset so the collection is hashable and can be passed as the
## argument of the lru_cache'd load_ort_sessions.
model_names = frozenset(
    f"BelgiumCropMap_unet_3BandsGenerator_Network{network_id}.onnx"
    for network_id in (1, 2, 3)
)
@functools.lru_cache(maxsize=25)
def load_ort_sessions(names, model_dir="onnx_models"):
    """
    Load the ONNX models as onnxruntime inference sessions.

    The lru_cache avoids loading the same models multiple times on the same
    worker. Because of that cache, every argument must be hashable, and the
    returned list is shared between calls, so callers must not mutate it.

    @param names: hashable iterable (e.g. frozenset or tuple) of model file
        names located inside ``model_dir``
    @param model_dir: directory holding the ``.onnx`` files; defaults to
        "onnx_models", the location used so far
    @return: list with one ``ort.InferenceSession`` per name, in the
        iteration order of ``names``
    """
    return [
        ort.InferenceSession(f"{model_dir}/{model_name}")
        for model_name in names
    ]
def process_window_onnx(ndvi_stack, patch_size=128):
    """
    Predict one patch with an ensemble of ONNX networks and take the median.

    Every session loaded for ``model_names`` is run 4 times. Each run feeds 3
    bands drawn at random (without replacement) from the last axis of
    ``ndvi_stack``. The per-pixel median over all runs is returned.

    @param ndvi_stack: numpy array of shape (patch_size, patch_size, n_bands)
    @param patch_size: side length of the square patch fed to the networks
    @return: (patch_size, patch_size) array with the median prediction, or
        None when fewer than 3 bands are available
    @raise ValueError: if the first two dimensions of ``ndvi_stack`` are not
        (patch_size, patch_size)
    """
    ## check whether you actually supplied 3 images or more
    nr_valid_bands = ndvi_stack.shape[2]
    ## the networks take exactly 3 bands; with fewer we cannot process this window
    if nr_valid_bands < 3:
        inspect(message="Not enough input data for this window -> skipping!")
        return None
    ## Without this check, a stack with the same pixel count but another
    ## shape (e.g. 64x256) would silently be reshaped into a wrong patch.
    if ndvi_stack.shape[:2] != (patch_size, patch_size):
        raise ValueError(
            f"Expected a ({patch_size}, {patch_size}, bands) stack, "
            f"got shape {ndvi_stack.shape}"
        )

    ## 3 networks x 4 random band triplets = 12 predictions
    ort_sessions = load_ort_sessions(model_names)
    number_per_model = 4
    prediction = np.zeros(
        (patch_size, patch_size, len(ort_sessions) * number_per_model)
    )
    for model_counter, ort_session in enumerate(ort_sessions):
        ## the input name does not change between runs of one session
        input_name = ort_session.get_inputs()[0].name
        for i in range(number_per_model):
            band_idx = np.random.choice(
                np.arange(nr_valid_bands), size=3, replace=False
            )
            input_data = ndvi_stack[:, :, band_idx].reshape(
                1, patch_size * patch_size, 3
            )
            ort_outputs = ort_session.run(None, {input_name: input_data})
            prediction[:, :, i + number_per_model * model_counter] = (
                ort_outputs[0].reshape((patch_size, patch_size))
            )
            ## free up some memory to avoid memory errors
            gc.collect()

    ## final prediction is the median of all predictions per pixel
    return np.median(prediction, axis=2)
def preprocess_datacube(cubearray):
    """
    Turn the input cube into a cleaned (x, y, t) NDVI stack for the networks.

    Only the first band of the cube is used. Values above 0.92 are clipped to
    0.92. Values <= -0.08 and NaN are treated as missing. The stack is shifted
    by +0.08, so valid values lie in (0, 1], and missing values are set to 0.

    Acquisitions are then selected:
    - if more than 3 contain no missing pixel at all, all of those are kept;
    - otherwise the 3 with the fewest missing pixels are kept.

    @param cubearray: xarray DataArray with dims "x", "bands", "y", "t"
    @return: numpy array of shape (x, y, n_selected_acquisitions)
    """
    ## First band only, as a numpy array in (x, y, t) order
    ndvi_stack = cubearray.isel(bands=0).transpose("x", "y", "t").values

    ## Clip high NDVI values. Written as "x > 0.92 -> 0.92" so NaN stays NaN:
    ## np.where(x < 0.92, x, 0.92) would turn every NaN into 0.92 (NaN
    ## comparisons are False), hiding missing data from the mask below.
    ndvi_stack = np.where(ndvi_stack > 0.92, 0.92, ndvi_stack)
    ## Values <= -0.08 are treated as invalid (NaN stays NaN)
    ndvi_stack = np.where(ndvi_stack > -0.08, ndvi_stack, np.nan)
    ## After the shift valid values are > 0, distinct from the 0 fill below
    ndvi_stack = ndvi_stack + 0.08

    ## Mask of missing values, used to rank acquisitions; then fill them with 0
    invalid = np.isnan(ndvi_stack)
    ndvi_stack[invalid] = 0

    ## Number of missing pixels per acquisition
    sum_invalid = invalid.sum(axis=(0, 1))
    clear_idx = np.flatnonzero(sum_invalid == 0)

    ## Enough fully clear acquisitions: use all of them (could be more than 3)
    if len(clear_idx) > 3:
        return ndvi_stack[:, :, clear_idx]
    ## Otherwise keep the 3 acquisitions with the fewest missing pixels;
    ## a stable sort prefers the earlier acquisition on ties
    idxsorted = np.argsort(sum_invalid, kind="stable")
    return ndvi_stack[:, :, idxsorted[:3]]
def apply_datacube(cube: DataArray, context: Dict) -> DataArray:
    """
    UDF entry point: predict one window of the input cube.

    The cube is reduced to an NDVI stack by preprocess_datacube and predicted
    by process_window_onnx. When that prediction is skipped (it returns None
    for fewer than 3 usable bands), an all-NaN prediction is returned instead
    of failing.

    @param cube: DataArray with dims "x", "bands", "y", "t"
    @param context: UDF context (not used)
    @return: float64 DataArray with new leading dims "t" (one timestamp:
        January 1st of the year of the cube's first time step) and "bands"
        (a single "prediction" band), followed by "x" and "y"
    """
    ## preprocess the datacube
    ndvi_stack = preprocess_datacube(cube)
    ## process the window
    result = process_window_onnx(ndvi_stack)
    if result is None:
        ## the window was skipped: emit NaN instead of crashing on None.astype
        result = np.full((cube.sizes["x"], cube.sizes["y"]), np.nan)

    ## transform the numpy predictions into an xarray
    result_xarray = xr.DataArray(
        result.astype(np.float64),
        dims=["x", "y"],
        coords={"x": cube.coords["x"], "y": cube.coords["y"]},
    )
    ## Reintroduce time and bands dimensions
    first_year = str(cube.t.dt.year.values[0])
    return result_xarray.expand_dims(
        dim={
            "t": [np.datetime64(first_year + "-01-01")],
            "bands": ["prediction"],
        },
    )