Skip to content

Commit

Permalink
Merge pull request #78 from bararchy/master
Browse files Browse the repository at this point in the history
Fixed for 0.25.0
  • Loading branch information
ArtLinkov authored Jun 17, 2018
2 parents d1d6e44 + be5f027 commit 7adac95
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 17 deletions.
8 changes: 5 additions & 3 deletions spec/network_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,11 @@ describe SHAInet::Network do
xor = SHAInet::Network.new

xor.add_layer(:input, 2, "memory", SHAInet.sigmoid)
1.times { |x| xor.add_layer(:hidden, 3, "memory", SHAInet.sigmoid) }
1.times { |x| xor.add_layer(:hidden, 2, "memory", SHAInet.sigmoid) }
xor.add_layer(:output, 1, "memory", SHAInet.sigmoid)
xor.fully_connect

xor.learning_rate = 0.7
xor.learning_rate = 0.1
xor.momentum = 0.3

xor.train(
Expand Down Expand Up @@ -450,7 +450,7 @@ describe SHAInet::Network do
learning_rate: 0.5,
sigma: 0.1,
cost_function: :c_ent,
epochs: 150,
epochs: 1000,
mini_batch_size: 5,
error_threshold: 1e-9,
log_each: 100,
Expand Down Expand Up @@ -557,3 +557,5 @@ end

# Remove train data
system("cd #{__DIR__}/test_data && rm *.csv")
File.delete("my_net.nn")
File.delete("xor.nn")
29 changes: 16 additions & 13 deletions src/shainet/basic/network.cr
Original file line number Diff line number Diff line change
Expand Up @@ -338,17 +338,17 @@ module SHAInet
proc = get_cost_proc(cost_function.to_s)
cost_function = proc
end

loop do |e|
counter = 0_i64
loop do
# Show training progress of epochs
if e % log_each == 0
log_summary(e)
if counter % log_each == 0
log_summary(counter)
# @all_neurons.each { |s| puts s.gradient }
end

# Break condtitions
if e >= epochs || (error_threshold >= @mse) && (e > 1)
log_summary(e)
if counter >= epochs || (error_threshold >= @mse) && (counter > 1)
log_summary(counter)
break
end

Expand Down Expand Up @@ -429,10 +429,11 @@ module SHAInet

# Show training progress of the mini-batches
i += 1
if e % log_each == 0
if counter % log_each == 0
@logger.info("Slice: (#{i} / #{slices}), MSE: #{@mse}") if show_slice
# @logger.info("@error_signal: #{@error_signal}")
end
counter += 1
end
end
end
Expand Down Expand Up @@ -571,15 +572,16 @@ module SHAInet
cost_function = proc
end

loop do |e|
if e >= epochs || (error_threshold >= @mse) && (e > 1)
log_summary(e)
counter = 0_i64
loop do
if counter >= epochs || (error_threshold >= @mse) && (counter > 1)
log_summary(counter)
break
end

# Show training progress of epochs
if e % log_each == 0
log_summary(e)
if counter % log_each == 0
log_summary(counter)
end

# Counters for disply
Expand Down Expand Up @@ -621,10 +623,11 @@ module SHAInet

# Show training progress of the mini-batches
i += 1
if e % log_each == 0
if counter % log_each == 0
@logger.info("Slice: (#{i} / #{slices}), MSE: #{@mse}") if show_slice
# @logger.info("@error_signal: #{@error_signal}")
end
counter += 1
end
end
end
Expand Down
1 change: 0 additions & 1 deletion xor.nn

This file was deleted.

0 comments on commit 7adac95

Please sign in to comment.