diff --git a/src/main/java/org/numenta/nupic/Connections.java b/src/main/java/org/numenta/nupic/Connections.java index 35f60973..91bd06d4 100644 --- a/src/main/java/org/numenta/nupic/Connections.java +++ b/src/main/java/org/numenta/nupic/Connections.java @@ -48,7 +48,11 @@ import org.numenta.nupic.model.ProximalDendrite; import org.numenta.nupic.model.Segment; import org.numenta.nupic.model.Synapse; +import org.numenta.nupic.network.Persistence; +import org.numenta.nupic.network.PersistenceAPI; +import org.numenta.nupic.serialize.SerialConfig; import org.numenta.nupic.util.AbstractSparseBinaryMatrix; +import org.numenta.nupic.util.ArrayUtils; import org.numenta.nupic.util.FlatMatrix; import org.numenta.nupic.util.SparseMatrix; import org.numenta.nupic.util.SparseObjectMatrix; @@ -69,6 +73,10 @@ public class Connections implements Persistable { private static final double EPSILON = 0.00001; /////////////////////////////////////// Spatial Pooler Vars /////////////////////////////////////////// + /** WARNING: potentialRadius **must** be set to + * the inputWidth if using "globalInhibition" and if not + * using the Network API (which sets this automatically) + */ private int potentialRadius = 16; private double potentialPct = 0.5; private boolean globalInhibition = false; @@ -242,12 +250,25 @@ public class Connections implements Persistable { */ public Connections() {} + /** + * Returns a deep copy of this {@code Connections} object. + * @return a deep copy of this {@code Connections} + */ + public Connections copy() { + PersistenceAPI api = Persistence.get(new SerialConfig()); + byte[] myBytes = api.serializer().serialize(this); + return api.serializer().deSerialize(myBytes); + } + /** * Sets the derived values of the {@link SpatialPooler}'s initialization. */ public void doSpatialPoolerPostInit() { synPermBelowStimulusInc = synPermConnected / 10.0; synPermTrimThreshold = synPermActiveInc / 2.0; + if(potentialRadius == -1) { + potentialRadius = ArrayUtils.product(inputDimensions); + } } ///////////////////////////////////////// @@ -480,6 +501,11 @@ public void setNumColumns(int n) { * parameter defines a square (or hyper square) area: a * column will have a max square potential pool with * sides of length 2 * potentialRadius + 1. + * + * WARNING: potentialRadius **must** be set to + * the inputWidth if using "globalInhibition" and if not + * using the Network API (which sets this automatically) + * * * @param potentialRadius */ @@ -489,11 +515,12 @@ public void setPotentialRadius(int potentialRadius) { /** * Returns the configured potential radius + * * @return the configured potential radius * @see setPotentialRadius */ public int getPotentialRadius() { - return Math.min(numInputs, potentialRadius); + return potentialRadius; } /** diff --git a/src/main/java/org/numenta/nupic/Parameters.java b/src/main/java/org/numenta/nupic/Parameters.java index 32c5315a..d2ed2e15 100644 --- a/src/main/java/org/numenta/nupic/Parameters.java +++ b/src/main/java/org/numenta/nupic/Parameters.java @@ -97,7 +97,7 @@ public class Parameters implements Persistable { //////////// Spatial Pooler Parameters /////////// Map defaultSpatialParams = new ParametersMap(); defaultSpatialParams.put(KEY.INPUT_DIMENSIONS, new int[]{64}); - defaultSpatialParams.put(KEY.POTENTIAL_RADIUS, 16); + defaultSpatialParams.put(KEY.POTENTIAL_RADIUS, -1); defaultSpatialParams.put(KEY.POTENTIAL_PCT, 0.5); defaultSpatialParams.put(KEY.GLOBAL_INHIBITION, false); defaultSpatialParams.put(KEY.INHIBITION_RADIUS, 0); @@ -225,6 +225,10 @@ public static enum KEY { /////////// Spatial Pooler Parameters /////////// INPUT_DIMENSIONS("inputDimensions", int[].class), + /** WARNING: potentialRadius **must** be set to + * the inputWidth if using "globalInhibition" and if not + * using the Network API (which sets this automatically) + */ POTENTIAL_RADIUS("potentialRadius", Integer.class), POTENTIAL_PCT("potentialPct", Double.class), //TODO add range here? GLOBAL_INHIBITION("globalInhibition", Boolean.class), @@ -770,6 +774,11 @@ public void setInputDimensions(int[] inputDimensions) { * parameter defines a square (or hyper square) area: a * column will have a max square potential pool with * sides of length 2 * potentialRadius + 1. + * + * WARNING: potentialRadius **must** be set to + * the inputWidth if using "globalInhibition" and if not + * using the Network API (which sets this automatically) + * * * @param potentialRadius */ diff --git a/src/main/java/org/numenta/nupic/network/Layer.java b/src/main/java/org/numenta/nupic/network/Layer.java index 6044bda1..aa9d6599 100644 --- a/src/main/java/org/numenta/nupic/network/Layer.java +++ b/src/main/java/org/numenta/nupic/network/Layer.java @@ -495,9 +495,7 @@ public Layer close() { params.setInputDimensions(upstreamDims); connections.setInputDimensions(upstreamDims); } else if(parentRegion != null && parentNetwork != null - && parentRegion.equals(parentNetwork.getSensorRegion()) && encoder == null - && spatialPooler != null) { - + && parentRegion.equals(parentNetwork.getSensorRegion()) && encoder == null && spatialPooler != null) { Layer curr = this; while((curr = curr.getPrevious()) != null) { if(curr.getEncoder() != null) { @@ -692,7 +690,7 @@ public Subscription subscribe(final Observer subscriber) { return createSubscription(subscriber); } - + /** * Allows the user to define the {@link Connections} object data structure * to use. Or possibly to share connections between two {@code Layer}s diff --git a/src/test/java/org/numenta/nupic/ConnectionsTest.java b/src/test/java/org/numenta/nupic/ConnectionsTest.java index abdd8564..8cb18589 100644 --- a/src/test/java/org/numenta/nupic/ConnectionsTest.java +++ b/src/test/java/org/numenta/nupic/ConnectionsTest.java @@ -27,8 +27,23 @@ import org.numenta.nupic.util.ArrayUtils; import org.numenta.nupic.util.MersenneTwister; +import com.cedarsoftware.util.DeepEquals; + public class ConnectionsTest { + @Test + public void testCopy() { + Parameters retVal = Parameters.getTemporalDefaultParameters(); + retVal.set(KEY.COLUMN_DIMENSIONS, new int[] { 32 }); + retVal.set(KEY.CELLS_PER_COLUMN, 4); + + Connections connections = new Connections(); + + retVal.apply(connections); + TemporalMemory.init(connections); + + assertTrue(DeepEquals.deepEquals(connections, connections.copy())); + } @Test public void testCreateSegment() { @@ -574,7 +589,7 @@ public void testGetPrintString() { TemporalMemory.init(con); String output = con.getPrintString(); - assertEquals(1370, output.length()); + assertEquals(1371, output.length()); Set fieldSet = Parameters.getEncoderDefaultParameters().keys().stream(). map(k -> k.getFieldName()).collect(Collectors.toCollection(LinkedHashSet::new)); diff --git a/src/test/java/org/numenta/nupic/algorithms/SpatialPoolerTest.java b/src/test/java/org/numenta/nupic/algorithms/SpatialPoolerTest.java index 50de1e5f..06bf901b 100644 --- a/src/test/java/org/numenta/nupic/algorithms/SpatialPoolerTest.java +++ b/src/test/java/org/numenta/nupic/algorithms/SpatialPoolerTest.java @@ -98,7 +98,7 @@ private void initSP() { parameters.apply(mem); sp.init(mem); } - + @Test public void confirmSPConstruction() { setupParameters(); diff --git a/src/test/java/org/numenta/nupic/network/NetworkTest.java b/src/test/java/org/numenta/nupic/network/NetworkTest.java index 21293cd3..0c29dc98 100644 --- a/src/test/java/org/numenta/nupic/network/NetworkTest.java +++ b/src/test/java/org/numenta/nupic/network/NetworkTest.java @@ -548,6 +548,7 @@ public void testBasicNetworkRunAWhileThenHalt() { @Test public void testRegionHierarchies() { Parameters p = NetworkTestHarness.getParameters(); + p.setPotentialRadius(16); p = p.union(NetworkTestHarness.getNetworkDemoTestEncoderParams()); p.set(KEY.RANDOM, new MersenneTwister(42)); @@ -1007,6 +1008,29 @@ public void testObservableWithCoordinateEncoder_NEGATIVE() { assertTrue(hasErrors(tester)); } + @Test + public void testPotentialRadiusFollowsInputWidth() { + Parameters p = NetworkTestHarness.getParameters(); + p = p.union(NetworkTestHarness.getNetworkDemoTestEncoderParams()); + p.set(KEY.INPUT_DIMENSIONS, new int[] { 200 }); + p.set(KEY.RANDOM, new MersenneTwister(42)); + + Network network = Network.create("test network", p) + .add(Network.createRegion("r1") + .add(Network.createLayer("2", p) + .add(Anomaly.create()) + .add(new TemporalMemory()) + .add(new SpatialPooler()) + .close())); + + Region r1 = network.lookup("r1"); + Layer layer2 = r1.lookup("2"); + + int width = layer2.calculateInputWidth(); + assertEquals(200, width); + assertEquals(200, layer2.getConnections().getPotentialRadius()); + } + /////////////////////////////////////////////////////////////////////////////////// // Tests of Calculate Input Width for inter-regional and inter-layer calcs // /////////////////////////////////////////////////////////////////////////////////// @@ -1063,7 +1087,6 @@ public void testCalculateInputWidth_NoPrevLayer_UpstreamRegion_without_TM() { int width = layer2.calculateInputWidth(); assertEquals(2048, width); - } @Test @@ -1077,7 +1100,6 @@ public void testCalculateInputWidth_NoPrevLayer_NoPrevRegion_andTM() { .add(Network.createLayer("2", p) .add(Anomaly.create()) .add(new TemporalMemory()) - //.add(new SpatialPooler()) .close())); Region r1 = network.lookup("r1"); @@ -1098,7 +1120,7 @@ public void testCalculateInputWidth_NoPrevLayer_NoPrevRegion_andSPTM() { .add(Network.createLayer("2", p) .add(Anomaly.create()) .add(new TemporalMemory()) - .add(new SpatialPooler()) + .add(new SpatialPooler()) .close())); Region r1 = network.lookup("r1"); @@ -1106,6 +1128,7 @@ public void testCalculateInputWidth_NoPrevLayer_NoPrevRegion_andSPTM() { int width = layer2.calculateInputWidth(); assertEquals(8, width); + assertEquals(8, layer2.getConnections().getPotentialRadius()); } @Test @@ -1126,6 +1149,7 @@ public void testCalculateInputWidth_NoPrevLayer_NoPrevRegion_andNoTM() { int width = layer2.calculateInputWidth(); assertEquals(8, width); + assertEquals(8, layer2.getConnections().getPotentialRadius()); } @Test diff --git a/src/test/java/org/numenta/nupic/network/NetworkTestHarness.java b/src/test/java/org/numenta/nupic/network/NetworkTestHarness.java index 15e45dad..37bf31c1 100644 --- a/src/test/java/org/numenta/nupic/network/NetworkTestHarness.java +++ b/src/test/java/org/numenta/nupic/network/NetworkTestHarness.java @@ -218,7 +218,7 @@ public static Parameters getParameters() { parameters.set(KEY.CELLS_PER_COLUMN, 6); //SpatialPooler specific - parameters.set(KEY.POTENTIAL_RADIUS, 12);//3 + parameters.set(KEY.POTENTIAL_RADIUS, -1);//3 parameters.set(KEY.POTENTIAL_PCT, 0.5);//0.5 parameters.set(KEY.GLOBAL_INHIBITION, false); parameters.set(KEY.LOCAL_AREA_DENSITY, -1.0);