Skip to content

Commit

Permalink
Merge pull request #484 from cogmission/potential_radius_fix
Browse files Browse the repository at this point in the history
Added fix for potential radius
  • Loading branch information
rhyolight authored Sep 28, 2016
2 parents 820eb63 + 831607c commit b99fd45
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 12 deletions.
29 changes: 28 additions & 1 deletion src/main/java/org/numenta/nupic/Connections.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@
import org.numenta.nupic.model.ProximalDendrite;
import org.numenta.nupic.model.Segment;
import org.numenta.nupic.model.Synapse;
import org.numenta.nupic.network.Persistence;
import org.numenta.nupic.network.PersistenceAPI;
import org.numenta.nupic.serialize.SerialConfig;
import org.numenta.nupic.util.AbstractSparseBinaryMatrix;
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.FlatMatrix;
import org.numenta.nupic.util.SparseMatrix;
import org.numenta.nupic.util.SparseObjectMatrix;
Expand All @@ -69,6 +73,10 @@ public class Connections implements Persistable {
private static final double EPSILON = 0.00001;

/////////////////////////////////////// Spatial Pooler Vars ///////////////////////////////////////////
/** <b>WARNING:</b> potentialRadius **must** be set to
* the inputWidth if using "globalInhibition" and if not
* using the Network API (which sets this automatically)
*/
private int potentialRadius = 16;
private double potentialPct = 0.5;
private boolean globalInhibition = false;
Expand Down Expand Up @@ -242,12 +250,25 @@ public class Connections implements Persistable {
*/
public Connections() {}

/**
* Returns a deep copy of this {@code Connections} object.
* @return a deep copy of this {@code Connections}
*/
public Connections copy() {
PersistenceAPI api = Persistence.get(new SerialConfig());
byte[] myBytes = api.serializer().serialize(this);
return api.serializer().deSerialize(myBytes);
}

/**
* Sets the derived values of the {@link SpatialPooler}'s initialization.
*/
public void doSpatialPoolerPostInit() {
synPermBelowStimulusInc = synPermConnected / 10.0;
synPermTrimThreshold = synPermActiveInc / 2.0;
if(potentialRadius == -1) {
potentialRadius = ArrayUtils.product(inputDimensions);
}
}

/////////////////////////////////////////
Expand Down Expand Up @@ -480,6 +501,11 @@ public void setNumColumns(int n) {
* parameter defines a square (or hyper square) area: a
* column will have a max square potential pool with
* sides of length 2 * potentialRadius + 1.
*
* <b>WARNING:</b> potentialRadius **must** be set to
* the inputWidth if using "globalInhibition" and if not
* using the Network API (which sets this automatically)
*
*
* @param potentialRadius
*/
Expand All @@ -489,11 +515,12 @@ public void setPotentialRadius(int potentialRadius) {

/**
* Returns the configured potential radius
*
* @return the configured potential radius
* @see setPotentialRadius
*/
public int getPotentialRadius() {
return Math.min(numInputs, potentialRadius);
return potentialRadius;
}

/**
Expand Down
11 changes: 10 additions & 1 deletion src/main/java/org/numenta/nupic/Parameters.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public class Parameters implements Persistable {
//////////// Spatial Pooler Parameters ///////////
Map<KEY, Object> defaultSpatialParams = new ParametersMap();
defaultSpatialParams.put(KEY.INPUT_DIMENSIONS, new int[]{64});
defaultSpatialParams.put(KEY.POTENTIAL_RADIUS, 16);
defaultSpatialParams.put(KEY.POTENTIAL_RADIUS, -1);
defaultSpatialParams.put(KEY.POTENTIAL_PCT, 0.5);
defaultSpatialParams.put(KEY.GLOBAL_INHIBITION, false);
defaultSpatialParams.put(KEY.INHIBITION_RADIUS, 0);
Expand Down Expand Up @@ -225,6 +225,10 @@ public static enum KEY {

/////////// Spatial Pooler Parameters ///////////
INPUT_DIMENSIONS("inputDimensions", int[].class),
/** <b>WARNING:</b> potentialRadius **must** be set to
* the inputWidth if using "globalInhibition" and if not
* using the Network API (which sets this automatically)
*/
POTENTIAL_RADIUS("potentialRadius", Integer.class),
POTENTIAL_PCT("potentialPct", Double.class), //TODO add range here?
GLOBAL_INHIBITION("globalInhibition", Boolean.class),
Expand Down Expand Up @@ -770,6 +774,11 @@ public void setInputDimensions(int[] inputDimensions) {
* parameter defines a square (or hyper square) area: a
* column will have a max square potential pool with
* sides of length 2 * potentialRadius + 1.
*
* <b>WARNING:</b> potentialRadius **must** be set to
* the inputWidth if using "globalInhibition" and if not
* using the Network API (which sets this automatically)
*
*
* @param potentialRadius
*/
Expand Down
6 changes: 2 additions & 4 deletions src/main/java/org/numenta/nupic/network/Layer.java
Original file line number Diff line number Diff line change
Expand Up @@ -495,9 +495,7 @@ public Layer<T> close() {
params.setInputDimensions(upstreamDims);
connections.setInputDimensions(upstreamDims);
} else if(parentRegion != null && parentNetwork != null
&& parentRegion.equals(parentNetwork.getSensorRegion()) && encoder == null
&& spatialPooler != null) {

&& parentRegion.equals(parentNetwork.getSensorRegion()) && encoder == null && spatialPooler != null) {
Layer<?> curr = this;
while((curr = curr.getPrevious()) != null) {
if(curr.getEncoder() != null) {
Expand Down Expand Up @@ -692,7 +690,7 @@ public Subscription subscribe(final Observer<Inference> subscriber) {

return createSubscription(subscriber);
}

/**
* Allows the user to define the {@link Connections} object data structure
* to use. Or possibly to share connections between two {@code Layer}s
Expand Down
17 changes: 16 additions & 1 deletion src/test/java/org/numenta/nupic/ConnectionsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,23 @@
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.MersenneTwister;

import com.cedarsoftware.util.DeepEquals;


public class ConnectionsTest {
@Test
public void testCopy() {
Parameters retVal = Parameters.getTemporalDefaultParameters();
retVal.set(KEY.COLUMN_DIMENSIONS, new int[] { 32 });
retVal.set(KEY.CELLS_PER_COLUMN, 4);

Connections connections = new Connections();

retVal.apply(connections);
TemporalMemory.init(connections);

assertTrue(DeepEquals.deepEquals(connections, connections.copy()));
}

@Test
public void testCreateSegment() {
Expand Down Expand Up @@ -574,7 +589,7 @@ public void testGetPrintString() {
TemporalMemory.init(con);

String output = con.getPrintString();
assertEquals(1370, output.length());
assertEquals(1371, output.length());

Set<String> fieldSet = Parameters.getEncoderDefaultParameters().keys().stream().
map(k -> k.getFieldName()).collect(Collectors.toCollection(LinkedHashSet::new));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ private void initSP() {
parameters.apply(mem);
sp.init(mem);
}

@Test
public void confirmSPConstruction() {
setupParameters();
Expand Down
30 changes: 27 additions & 3 deletions src/test/java/org/numenta/nupic/network/NetworkTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,7 @@ public void testBasicNetworkRunAWhileThenHalt() {
@Test
public void testRegionHierarchies() {
Parameters p = NetworkTestHarness.getParameters();
p.setPotentialRadius(16);
p = p.union(NetworkTestHarness.getNetworkDemoTestEncoderParams());
p.set(KEY.RANDOM, new MersenneTwister(42));

Expand Down Expand Up @@ -1007,6 +1008,29 @@ public void testObservableWithCoordinateEncoder_NEGATIVE() {
assertTrue(hasErrors(tester));
}

@Test
public void testPotentialRadiusFollowsInputWidth() {
Parameters p = NetworkTestHarness.getParameters();
p = p.union(NetworkTestHarness.getNetworkDemoTestEncoderParams());
p.set(KEY.INPUT_DIMENSIONS, new int[] { 200 });
p.set(KEY.RANDOM, new MersenneTwister(42));

Network network = Network.create("test network", p)
.add(Network.createRegion("r1")
.add(Network.createLayer("2", p)
.add(Anomaly.create())
.add(new TemporalMemory())
.add(new SpatialPooler())
.close()));

Region r1 = network.lookup("r1");
Layer<?> layer2 = r1.lookup("2");

int width = layer2.calculateInputWidth();
assertEquals(200, width);
assertEquals(200, layer2.getConnections().getPotentialRadius());
}

///////////////////////////////////////////////////////////////////////////////////
// Tests of Calculate Input Width for inter-regional and inter-layer calcs //
///////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -1063,7 +1087,6 @@ public void testCalculateInputWidth_NoPrevLayer_UpstreamRegion_without_TM() {

int width = layer2.calculateInputWidth();
assertEquals(2048, width);

}

@Test
Expand All @@ -1077,7 +1100,6 @@ public void testCalculateInputWidth_NoPrevLayer_NoPrevRegion_andTM() {
.add(Network.createLayer("2", p)
.add(Anomaly.create())
.add(new TemporalMemory())
//.add(new SpatialPooler())
.close()));

Region r1 = network.lookup("r1");
Expand All @@ -1098,14 +1120,15 @@ public void testCalculateInputWidth_NoPrevLayer_NoPrevRegion_andSPTM() {
.add(Network.createLayer("2", p)
.add(Anomaly.create())
.add(new TemporalMemory())
.add(new SpatialPooler())
.add(new SpatialPooler())
.close()));

Region r1 = network.lookup("r1");
Layer<?> layer2 = r1.lookup("2");

int width = layer2.calculateInputWidth();
assertEquals(8, width);
assertEquals(8, layer2.getConnections().getPotentialRadius());
}

@Test
Expand All @@ -1126,6 +1149,7 @@ public void testCalculateInputWidth_NoPrevLayer_NoPrevRegion_andNoTM() {

int width = layer2.calculateInputWidth();
assertEquals(8, width);
assertEquals(8, layer2.getConnections().getPotentialRadius());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ public static Parameters getParameters() {
parameters.set(KEY.CELLS_PER_COLUMN, 6);

//SpatialPooler specific
parameters.set(KEY.POTENTIAL_RADIUS, 12);//3
parameters.set(KEY.POTENTIAL_RADIUS, -1);//3
parameters.set(KEY.POTENTIAL_PCT, 0.5);//0.5
parameters.set(KEY.GLOBAL_INHIBITION, false);
parameters.set(KEY.LOCAL_AREA_DENSITY, -1.0);
Expand Down

0 comments on commit b99fd45

Please sign in to comment.