Skip to content

Commit

Permalink
move dataloader code to its own file
Browse files Browse the repository at this point in the history
  • Loading branch information
axsk committed Jul 11, 2023
1 parent 2c8b04a commit 150df45
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
1 change: 1 addition & 0 deletions src/ISOKANN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ include("performance.jl") # performance metric loggers
include("benchmarks.jl") # benchmark runs, deprecated by scripts/*

include("cuda.jl") # fixes for cuda
include("dataloader.jl")


include("reactionpath.jl")
Expand Down
10 changes: 1 addition & 9 deletions src/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,7 @@ function datasubsample(model, data, nx)
return xs, ys
end

# this allows to use the current isokann implementation with a DataLoader
# where (according to MLUtils practice) the ys have shape (dim x nkoop x npoints)
# we therefore permute the last dims to adhere to the ISOKANN.jl standard
using MLUtils
function datasubsample(model, data::DataLoader, nx)
x, y = first(data)
y = permutedims(y, (1, 3, 2))
return (x, y)
end



function subsample_inds(model, xs, n)
Expand Down
9 changes: 9 additions & 0 deletions src/dataloader.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# this allows to use the current isokann implementation with a DataLoader
# where (according to MLUtils practice) the ys have shape (dim x nkoop x npoints)
# we therefore permute the last dims to adhere to the ISOKANN.jl standard
using MLUtils
function datasubsample(model, data::DataLoader, nx)
x, y = first(data)
y = permutedims(y, (1, 3, 2))
return (x, y)
end

0 comments on commit 150df45

Please sign in to comment.