Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support user-specified callbacks after each SOM training epoch #183

Merged
merged 1 commit into from
Apr 1, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 26 additions & 39 deletions src/analysis/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,19 @@ end


"""
trainGigaSOM(som::Som, dInfo::LoadedDataInfo;
kernelFun::Function = gaussianKernel,
metric = Euclidean(),
somDistFun = distMatrix(Chebyshev()),
knnTreeFun = BruteTree,
rStart = 0.0, rFinal=0.1, radiusFun=expRadius(-5.0),
epochs = 20)
trainGigaSOM(
som::Som,
dInfo::LoadedDataInfo;
kernelFun::Function = gaussianKernel,
metric = Euclidean(),
somDistFun = distMatrix(Chebyshev()),
knnTreeFun = BruteTree,
rStart = 0.0,
rFinal = 0.1,
radiusFun = expRadius(-5.0),
epochs = 20,
eachEpoch = (e, r, codes) -> nothing,
)

# Arguments:
- `som`: object of type Som with an initialised som
Expand All @@ -102,6 +108,9 @@ end
- `rFinal`: target radius at the last epoch, defaults to 0.1
- `radiusFun`: Function that generates radius decay, e.g. `linearRadius` or `expRadius(10.0)`
- `epochs`: number of SOM training iterations (default 20)
- `eachEpoch`: a function to call back after each epoch, accepting arguments
`(epochNumber, radius, codes)`. For simplicity, this gets additionally called
once before the first epoch, with `epochNumber` set to zero.
"""
function trainGigaSOM(
som::Som,
Expand All @@ -114,6 +123,7 @@ function trainGigaSOM(
rFinal = 0.1,
radiusFun = expRadius(-5.0),
epochs = 20,
eachEpoch = (e, r, codes) -> nothing,
)

# set the default radius
Expand All @@ -127,6 +137,8 @@ function trainGigaSOM(

codes = som.codes

eachEpoch(0, rStart, codes)

for j = 1:epochs
@debug "Epoch $j..."

Expand All @@ -145,6 +157,8 @@ function trainGigaSOM(

wEpoch = kernelFun(dm, r)
codes = (wEpoch * numerator) ./ (wEpoch * denominator)

eachEpoch(epoch, r, codes)
end

som.codes = copy(codes)
Expand All @@ -154,46 +168,19 @@ end

"""
    trainGigaSOM(som::Som, train; kwargs...)

Overload of `trainGigaSOM` for simple DataFrames and matrices. This slices the
data, distributes them to the workers, and runs normal `trainGigaSOM`. Data is
`undistribute`d after the computation.

# Arguments:
- `som`: object of type Som with an initialised som
- `train`: a DataFrame or matrix of training data; converted to `Matrix{Float64}`
- `kwargs...`: forwarded verbatim to the distributed `trainGigaSOM` method

# Returns:
The trained `Som` produced by the distributed `trainGigaSOM`.
"""
function trainGigaSOM(som::Som, train; kwargs...)

    train = Matrix{Float64}(train)

    # this slices the data into parts and sends them to the workers
    dInfo = distribute_array(:GigaSOMtrainDataVar, train, workers())
    # run the actual training on the distributed data, forwarding all kwargs
    som_res = trainGigaSOM(som, dInfo; kwargs...)
    # release the distributed data on the workers before returning
    undistribute(dInfo)
    return som_res
end
Expand Down