diff --git a/src/main/java/org/numenta/nupic/Connections.java b/src/main/java/org/numenta/nupic/Connections.java
index 35f60973..91bd06d4 100644
--- a/src/main/java/org/numenta/nupic/Connections.java
+++ b/src/main/java/org/numenta/nupic/Connections.java
@@ -48,7 +48,11 @@
import org.numenta.nupic.model.ProximalDendrite;
import org.numenta.nupic.model.Segment;
import org.numenta.nupic.model.Synapse;
+import org.numenta.nupic.network.Persistence;
+import org.numenta.nupic.network.PersistenceAPI;
+import org.numenta.nupic.serialize.SerialConfig;
import org.numenta.nupic.util.AbstractSparseBinaryMatrix;
+import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.FlatMatrix;
import org.numenta.nupic.util.SparseMatrix;
import org.numenta.nupic.util.SparseObjectMatrix;
@@ -69,6 +73,10 @@ public class Connections implements Persistable {
private static final double EPSILON = 0.00001;
/////////////////////////////////////// Spatial Pooler Vars ///////////////////////////////////////////
+ /** WARNING: potentialRadius **must** be set to
+ * the inputWidth if using "globalInhibition" and if not
+ * using the Network API (which sets this automatically)
+ */
private int potentialRadius = 16;
private double potentialPct = 0.5;
private boolean globalInhibition = false;
@@ -242,12 +250,25 @@ public class Connections implements Persistable {
*/
public Connections() {}
+ /**
+ * Returns a deep copy of this {@code Connections} object.
+ * @return a deep copy of this {@code Connections}
+ */
+ public Connections copy() {
+ PersistenceAPI api = Persistence.get(new SerialConfig());
+ byte[] myBytes = api.serializer().serialize(this);
+ return api.serializer().deSerialize(myBytes);
+ }
+
/**
* Sets the derived values of the {@link SpatialPooler}'s initialization.
*/
public void doSpatialPoolerPostInit() {
synPermBelowStimulusInc = synPermConnected / 10.0;
synPermTrimThreshold = synPermActiveInc / 2.0;
+ if(potentialRadius == -1) {
+ potentialRadius = ArrayUtils.product(inputDimensions);
+ }
}
/////////////////////////////////////////
@@ -480,6 +501,11 @@ public void setNumColumns(int n) {
* parameter defines a square (or hyper square) area: a
* column will have a max square potential pool with
* sides of length 2 * potentialRadius + 1.
+ *
+ * WARNING: potentialRadius **must** be set to
+ * the inputWidth if using "globalInhibition" and if not
+ * using the Network API (which sets this automatically)
+ *
*
* @param potentialRadius
*/
@@ -489,11 +515,12 @@ public void setPotentialRadius(int potentialRadius) {
/**
* Returns the configured potential radius
+ *
* @return the configured potential radius
* @see setPotentialRadius
*/
public int getPotentialRadius() {
- return Math.min(numInputs, potentialRadius);
+ return potentialRadius;
}
/**
diff --git a/src/main/java/org/numenta/nupic/Parameters.java b/src/main/java/org/numenta/nupic/Parameters.java
index 32c5315a..d2ed2e15 100644
--- a/src/main/java/org/numenta/nupic/Parameters.java
+++ b/src/main/java/org/numenta/nupic/Parameters.java
@@ -97,7 +97,7 @@ public class Parameters implements Persistable {
//////////// Spatial Pooler Parameters ///////////
Map defaultSpatialParams = new ParametersMap();
defaultSpatialParams.put(KEY.INPUT_DIMENSIONS, new int[]{64});
- defaultSpatialParams.put(KEY.POTENTIAL_RADIUS, 16);
+ defaultSpatialParams.put(KEY.POTENTIAL_RADIUS, -1);
defaultSpatialParams.put(KEY.POTENTIAL_PCT, 0.5);
defaultSpatialParams.put(KEY.GLOBAL_INHIBITION, false);
defaultSpatialParams.put(KEY.INHIBITION_RADIUS, 0);
@@ -225,6 +225,10 @@ public static enum KEY {
/////////// Spatial Pooler Parameters ///////////
INPUT_DIMENSIONS("inputDimensions", int[].class),
+ /** WARNING: potentialRadius **must** be set to
+ * the inputWidth if using "globalInhibition" and if not
+ * using the Network API (which sets this automatically)
+ */
POTENTIAL_RADIUS("potentialRadius", Integer.class),
POTENTIAL_PCT("potentialPct", Double.class), //TODO add range here?
GLOBAL_INHIBITION("globalInhibition", Boolean.class),
@@ -770,6 +774,11 @@ public void setInputDimensions(int[] inputDimensions) {
* parameter defines a square (or hyper square) area: a
* column will have a max square potential pool with
* sides of length 2 * potentialRadius + 1.
+ *
+ * WARNING: potentialRadius **must** be set to
+ * the inputWidth if using "globalInhibition" and if not
+ * using the Network API (which sets this automatically)
+ *
*
* @param potentialRadius
*/
diff --git a/src/main/java/org/numenta/nupic/network/Layer.java b/src/main/java/org/numenta/nupic/network/Layer.java
index 6044bda1..aa9d6599 100644
--- a/src/main/java/org/numenta/nupic/network/Layer.java
+++ b/src/main/java/org/numenta/nupic/network/Layer.java
@@ -495,9 +495,7 @@ public Layer close() {
params.setInputDimensions(upstreamDims);
connections.setInputDimensions(upstreamDims);
} else if(parentRegion != null && parentNetwork != null
- && parentRegion.equals(parentNetwork.getSensorRegion()) && encoder == null
- && spatialPooler != null) {
-
+ && parentRegion.equals(parentNetwork.getSensorRegion()) && encoder == null && spatialPooler != null) {
Layer> curr = this;
while((curr = curr.getPrevious()) != null) {
if(curr.getEncoder() != null) {
@@ -692,7 +690,7 @@ public Subscription subscribe(final Observer subscriber) {
return createSubscription(subscriber);
}
-
+
/**
* Allows the user to define the {@link Connections} object data structure
* to use. Or possibly to share connections between two {@code Layer}s
diff --git a/src/test/java/org/numenta/nupic/ConnectionsTest.java b/src/test/java/org/numenta/nupic/ConnectionsTest.java
index abdd8564..8cb18589 100644
--- a/src/test/java/org/numenta/nupic/ConnectionsTest.java
+++ b/src/test/java/org/numenta/nupic/ConnectionsTest.java
@@ -27,8 +27,23 @@
import org.numenta.nupic.util.ArrayUtils;
import org.numenta.nupic.util.MersenneTwister;
+import com.cedarsoftware.util.DeepEquals;
+
public class ConnectionsTest {
+ @Test
+ public void testCopy() {
+ Parameters retVal = Parameters.getTemporalDefaultParameters();
+ retVal.set(KEY.COLUMN_DIMENSIONS, new int[] { 32 });
+ retVal.set(KEY.CELLS_PER_COLUMN, 4);
+
+ Connections connections = new Connections();
+
+ retVal.apply(connections);
+ TemporalMemory.init(connections);
+
+ assertTrue(DeepEquals.deepEquals(connections, connections.copy()));
+ }
@Test
public void testCreateSegment() {
@@ -574,7 +589,7 @@ public void testGetPrintString() {
TemporalMemory.init(con);
String output = con.getPrintString();
- assertEquals(1370, output.length());
+ assertEquals(1371, output.length());
Set fieldSet = Parameters.getEncoderDefaultParameters().keys().stream().
map(k -> k.getFieldName()).collect(Collectors.toCollection(LinkedHashSet::new));
diff --git a/src/test/java/org/numenta/nupic/algorithms/SpatialPoolerTest.java b/src/test/java/org/numenta/nupic/algorithms/SpatialPoolerTest.java
index 50de1e5f..06bf901b 100644
--- a/src/test/java/org/numenta/nupic/algorithms/SpatialPoolerTest.java
+++ b/src/test/java/org/numenta/nupic/algorithms/SpatialPoolerTest.java
@@ -98,7 +98,7 @@ private void initSP() {
parameters.apply(mem);
sp.init(mem);
}
-
+
@Test
public void confirmSPConstruction() {
setupParameters();
diff --git a/src/test/java/org/numenta/nupic/network/NetworkTest.java b/src/test/java/org/numenta/nupic/network/NetworkTest.java
index 21293cd3..0c29dc98 100644
--- a/src/test/java/org/numenta/nupic/network/NetworkTest.java
+++ b/src/test/java/org/numenta/nupic/network/NetworkTest.java
@@ -548,6 +548,7 @@ public void testBasicNetworkRunAWhileThenHalt() {
@Test
public void testRegionHierarchies() {
Parameters p = NetworkTestHarness.getParameters();
+ p.setPotentialRadius(16);
p = p.union(NetworkTestHarness.getNetworkDemoTestEncoderParams());
p.set(KEY.RANDOM, new MersenneTwister(42));
@@ -1007,6 +1008,29 @@ public void testObservableWithCoordinateEncoder_NEGATIVE() {
assertTrue(hasErrors(tester));
}
+ @Test
+ public void testPotentialRadiusFollowsInputWidth() {
+ Parameters p = NetworkTestHarness.getParameters();
+ p = p.union(NetworkTestHarness.getNetworkDemoTestEncoderParams());
+ p.set(KEY.INPUT_DIMENSIONS, new int[] { 200 });
+ p.set(KEY.RANDOM, new MersenneTwister(42));
+
+ Network network = Network.create("test network", p)
+ .add(Network.createRegion("r1")
+ .add(Network.createLayer("2", p)
+ .add(Anomaly.create())
+ .add(new TemporalMemory())
+ .add(new SpatialPooler())
+ .close()));
+
+ Region r1 = network.lookup("r1");
+ Layer> layer2 = r1.lookup("2");
+
+ int width = layer2.calculateInputWidth();
+ assertEquals(200, width);
+ assertEquals(200, layer2.getConnections().getPotentialRadius());
+ }
+
///////////////////////////////////////////////////////////////////////////////////
// Tests of Calculate Input Width for inter-regional and inter-layer calcs //
///////////////////////////////////////////////////////////////////////////////////
@@ -1063,7 +1087,6 @@ public void testCalculateInputWidth_NoPrevLayer_UpstreamRegion_without_TM() {
int width = layer2.calculateInputWidth();
assertEquals(2048, width);
-
}
@Test
@@ -1077,7 +1100,6 @@ public void testCalculateInputWidth_NoPrevLayer_NoPrevRegion_andTM() {
.add(Network.createLayer("2", p)
.add(Anomaly.create())
.add(new TemporalMemory())
- //.add(new SpatialPooler())
.close()));
Region r1 = network.lookup("r1");
@@ -1098,7 +1120,7 @@ public void testCalculateInputWidth_NoPrevLayer_NoPrevRegion_andSPTM() {
.add(Network.createLayer("2", p)
.add(Anomaly.create())
.add(new TemporalMemory())
- .add(new SpatialPooler())
+ .add(new SpatialPooler())
.close()));
Region r1 = network.lookup("r1");
@@ -1106,6 +1128,7 @@ public void testCalculateInputWidth_NoPrevLayer_NoPrevRegion_andSPTM() {
int width = layer2.calculateInputWidth();
assertEquals(8, width);
+ assertEquals(8, layer2.getConnections().getPotentialRadius());
}
@Test
@@ -1126,6 +1149,7 @@ public void testCalculateInputWidth_NoPrevLayer_NoPrevRegion_andNoTM() {
int width = layer2.calculateInputWidth();
assertEquals(8, width);
+ assertEquals(8, layer2.getConnections().getPotentialRadius());
}
@Test
diff --git a/src/test/java/org/numenta/nupic/network/NetworkTestHarness.java b/src/test/java/org/numenta/nupic/network/NetworkTestHarness.java
index 15e45dad..37bf31c1 100644
--- a/src/test/java/org/numenta/nupic/network/NetworkTestHarness.java
+++ b/src/test/java/org/numenta/nupic/network/NetworkTestHarness.java
@@ -218,7 +218,7 @@ public static Parameters getParameters() {
parameters.set(KEY.CELLS_PER_COLUMN, 6);
//SpatialPooler specific
- parameters.set(KEY.POTENTIAL_RADIUS, 12);//3
+ parameters.set(KEY.POTENTIAL_RADIUS, -1);//3
parameters.set(KEY.POTENTIAL_PCT, 0.5);//0.5
parameters.set(KEY.GLOBAL_INHIBITION, false);
parameters.set(KEY.LOCAL_AREA_DENSITY, -1.0);