Skip to content

Commit

Permalink
Merge pull request #11 from esa/batching_improvement
Browse files Browse the repository at this point in the history
batching
  • Loading branch information
Sceki authored Jul 5, 2024
2 parents 69ba9e3 + d1bcea1 commit 152df44
Show file tree
Hide file tree
Showing 9 changed files with 613 additions and 409 deletions.
23 changes: 12 additions & 11 deletions doc/notebooks/covariance_propagation.ipynb

Large diffs are not rendered by default.

545 changes: 273 additions & 272 deletions doc/notebooks/sgp4_partial_derivatives.ipynb

Large diffs are not rendered by default.

21 changes: 10 additions & 11 deletions doc/notebooks/tle_propagation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand All @@ -53,7 +53,7 @@
" 5.5809e-01, 6.2651e-02, 4.8993e+00])"
]
},
"execution_count": 10,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -126,34 +126,33 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"#we first need to prepare the data, the API requires that there are as many TLEs as times. Let us assume we want to\n",
"#propagate each of the \n",
"tles_=[]\n",
"for tle in tles:\n",
" tles_+=[tle]*10000\n",
"tsinces = torch.cat([torch.linspace(0,24*60,10000)]*len(tles))\n",
"#first let's initialize them:\n",
"dsgp4.initialize_tle(tles)\n",
"#then let's construct the TLEs batch by making sure there are as many TLEs as times:\n",
"tles_batch=[]\n",
"for tle in tles:\n",
" tles_batch+=[tle]*10000"
"_,tle_batch=dsgp4.initialize_tle(tles_)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"#we propagate the batch of 3,000 TLEs for 1 day:\n",
"states_teme=dsgp4.propagate_batch(tles_batch,tsinces)"
"states_teme=dsgp4.propagate_batch(tle_batch,tsinces)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 16,
"metadata": {},
"outputs": [
{
Expand Down
3 changes: 2 additions & 1 deletion dsgp4/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
__version__ = '0.1.2'
__version__ = '1.0.0'

import torch
torch.set_default_dtype(torch.float64)
from .sgp4 import sgp4
from .initl import initl
from .sgp4init import sgp4init
from .sgp4init_batch import sgp4init_batch
from .newton_method import newton_method, update_TLE
from .sgp4_batched import sgp4_batched
from .util import propagate, initialize_tle, propagate_batch
Expand Down
95 changes: 10 additions & 85 deletions dsgp4/sgp4_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy
from .tle import TLE

def sgp4_batched(satellite, tsince):
def sgp4_batched(satellite_batch, tsince):
"""
This function represents the batch SGP4 propagator.
It resembles `sgp4`, but accepts batches of TLEs.
Expand All @@ -12,99 +12,25 @@ def sgp4_batched(satellite, tsince):
in km and km/s, respectively, after `tsince` minutes.
Args:
- satellite (``dsgp4.tle.TLE``): TLE object
- tsince (``torch.tensor``): time to propagate, since the TLE epoch, in minutes
- satellite (``dsgp4.tle.TLE``): TLE batch object (with attributes that are N-dimensional tensors)
- tsince (``torch.tensor``): time to propagate, since the TLE epoch, in minutes (also an N-dimensional tensor)
Returns:
- batch_state (``torch.tensor``): a batch of 2x3 tensors, where the first row represents the spacecraft
position (in km) and the second the spacecraft velocity (in km/s)
"""
if not isinstance(satellite, list):
raise ValueError("satellite should be a list of TLE objects.")
if not isinstance(satellite[0],TLE):
raise ValueError("satellite should be a list of TLE objects.")
if not isinstance(satellite_batch, TLE):
raise ValueError("satellite_batch should be a TLE object.")
if not torch.is_tensor(tsince):
raise ValueError("tsince must be a tensor.")
if tsince.ndim!=1:
raise ValueError("tsince should be a one dimensional tensor.")
if len(tsince)!=len(satellite):
raise ValueError("in batch mode, tsince and satellite shall be of same length.")
if not hasattr(satellite[0], '_radiusearthkm'):
raise AttributeError('It looks like the satellite has not been initialized. Please use the `initialize_tle` method or directly `sgp4init` to initialize the satellite. Otherwise, if you are propagating, another option is to use `dsgp4.propagate` and pass `initialized=True` in the arguments.')
if len(tsince)!=len(satellite_batch._argpo):
raise ValueError(f"in batch mode, tsince and satellite_batch shall have attributes of same length. Instead {len(tsince)} for time, and {len(satellite_batch._argpo)} for satellites' attributes found")
if not hasattr(satellite_batch, '_radiusearthkm'):
raise AttributeError('It looks like the satellite_batch has not been initialized. Please use the `initialize_tle` method or directly `sgp4init` to initialize the satellite_batch. Otherwise, if you are propagating, another option is to use `dsgp4.propagate` and pass `initialized=True` in the arguments.')

batch_size = len(satellite)

satellite_batch=satellite[0].copy()
satellite_batch._bstar=torch.stack([s._bstar for s in satellite])
satellite_batch._ndot=torch.stack([s._ndot for s in satellite])
satellite_batch._nddot=torch.stack([s._nddot for s in satellite])
satellite_batch._ecco=torch.stack([s._ecco for s in satellite])
satellite_batch._argpo=torch.stack([s._argpo for s in satellite])
satellite_batch._inclo=torch.stack([s._inclo for s in satellite])
satellite_batch._mo=torch.stack([s._mo for s in satellite])

satellite_batch._no_kozai=torch.stack([s._no_kozai for s in satellite])
satellite_batch._nodeo=torch.stack([s._nodeo for s in satellite])
satellite_batch.satellite_catalog_number=torch.tensor([s.satellite_catalog_number for s in satellite])
satellite_batch._jdsatepoch=torch.stack([s._jdsatepoch for s in satellite])
satellite_batch._jdsatepochF=torch.stack([s._jdsatepochF for s in satellite])
satellite_batch._isimp=torch.tensor([s._isimp for s in satellite])
satellite_batch._method=[s._method for s in satellite]

satellite_batch._mdot=torch.stack([s._mdot for s in satellite])
satellite_batch._argpdot=torch.stack([s._argpdot for s in satellite])
satellite_batch._nodedot=torch.stack([s._nodedot for s in satellite])
satellite_batch._nodecf=torch.stack([s._nodecf for s in satellite])
satellite_batch._cc1=torch.stack([s._cc1 for s in satellite])
satellite_batch._cc4=torch.stack([s._cc4 for s in satellite])
satellite_batch._cc5=torch.stack([s._cc5 for s in satellite])
satellite_batch._t2cof=torch.stack([s._t2cof for s in satellite])

satellite_batch._omgcof=torch.stack([s._omgcof for s in satellite])
satellite_batch._eta=torch.stack([s._eta for s in satellite])
satellite_batch._xmcof=torch.stack([s._xmcof for s in satellite])
satellite_batch._delmo=torch.stack([s._delmo for s in satellite])
satellite_batch._d2=torch.stack([s._d2 for s in satellite])
satellite_batch._d3=torch.stack([s._d3 for s in satellite])
satellite_batch._d4=torch.stack([s._d4 for s in satellite])
satellite_batch._cc5=torch.stack([s._cc5 for s in satellite])
satellite_batch._sinmao=torch.stack([s._sinmao for s in satellite])
satellite_batch._t3cof=torch.stack([s._t3cof for s in satellite])
satellite_batch._t4cof=torch.stack([s._t4cof for s in satellite])
satellite_batch._t5cof=torch.stack([s._t5cof for s in satellite])

satellite_batch._xke=torch.stack([s._xke for s in satellite])
satellite_batch._radiusearthkm=torch.stack([s._radiusearthkm for s in satellite])
satellite_batch._t=torch.stack([s._t for s in satellite])
satellite_batch._aycof=torch.stack([s._aycof for s in satellite])
satellite_batch._x1mth2=torch.stack([s._x1mth2 for s in satellite])
satellite_batch._con41=torch.stack([s._con41 for s in satellite])
satellite_batch._x7thm1=torch.stack([s._x7thm1 for s in satellite])
satellite_batch._xlcof=torch.stack([s._xlcof for s in satellite])
satellite_batch._tumin=torch.stack([s._tumin for s in satellite])
satellite_batch._mu=torch.stack([s._mu for s in satellite])
satellite_batch._j2=torch.stack([s._j2 for s in satellite])
satellite_batch._j3=torch.stack([s._j3 for s in satellite])
satellite_batch._j4=torch.stack([s._j4 for s in satellite])
satellite_batch._j3oj2=torch.stack([s._j3oj2 for s in satellite])
satellite_batch._error=torch.stack([s._error for s in satellite])
satellite_batch._operationmode=[s._operationmode for s in satellite]
satellite_batch._satnum=torch.tensor([s._satnum for s in satellite])
satellite_batch._am=torch.stack([s._am for s in satellite])
satellite_batch._em=torch.stack([s._em for s in satellite])
satellite_batch._im=torch.stack([s._im for s in satellite])
satellite_batch._Om=torch.stack([s._Om for s in satellite])
satellite_batch._mm=torch.stack([s._mm for s in satellite])
satellite_batch._nm=torch.stack([s._nm for s in satellite])
satellite_batch._init=[s._init for s in satellite]

satellite_batch._no_unkozai=torch.stack([s._no_unkozai for s in satellite])
satellite_batch._a=torch.stack([s._a for s in satellite])
satellite_batch._alta=torch.stack([s._altp for s in satellite])




batch_size = len(tsince)
mrt = torch.zeros(batch_size)
x2o3 = torch.tensor(2.0 / 3.0)

Expand All @@ -125,7 +51,6 @@ def sgp4_batched(satellite, tsince):
tempe1 = satellite_batch._bstar * satellite_batch._cc4 * satellite_batch._t
templ1 = satellite_batch._t2cof * t2



delomg = satellite_batch._omgcof * satellite_batch._t

Expand Down
6 changes: 3 additions & 3 deletions dsgp4/sgp4init.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def sgp4init(
temp4 = torch.tensor(1.5e-12)

# ----------- set all near earth variables to zero ------------
satellite._isimp = torch.tensor(0); satellite._method = 'n'; satellite._aycof = torch.tensor(0.0);
satellite._isimp = torch.tensor(0); satellite._method = 'n'; satellite._aycof = torch.tensor(0.0);
satellite._con41 = torch.tensor(0.0); satellite._cc1 = torch.tensor(0.0); satellite._cc4 = torch.tensor(0.0);
satellite._cc5 = torch.tensor(0.0); satellite._d2 = torch.tensor(0.0); satellite._d3 = torch.tensor(0.0);
satellite._d4 = torch.tensor(0.0); satellite._delmo = torch.tensor(0.0); satellite._eta = torch.tensor(0.0);
Expand Down Expand Up @@ -198,6 +198,6 @@ def sgp4init(
12.0 * satellite._cc1 * satellite._d3 +
6.0 * satellite._d2 * satellite._d2 +
15.0 * cc1sq * (2.0 * satellite._d2 + cc1sq))
sgp4(satellite, torch.zeros(1,1))
sgp4(satellite, torch.zeros(1,1));

satellite._init = 'n'
satellite._init = 'n'
Loading

0 comments on commit 152df44

Please sign in to comment.