Skip to content

Commit

Permalink
Merge pull request #419 from cogmission/slight_persistence_patch
Browse files Browse the repository at this point in the history
Slight persistence patch and HTMObjectIn/Output constructor change.
  • Loading branch information
cogmission committed May 2, 2016
2 parents 95ddd4c + 22ba2d7 commit 4684f8b
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 10 deletions.
5 changes: 3 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ apply plugin: 'eclipse'
apply plugin: 'signing'

group = 'org.numenta'
version = '0.6.7-SNAPSHOT'
version = '0.6.8'
archivesBaseName = 'htm.java'

sourceCompatibility = 1.8
targetCompatibility = 1.8

jar {
manifest {
attributes 'Implementation-Title': 'htm.java', 'Implementation-Version': '0.6.7-SNAPSHOT'
attributes 'Implementation-Title': 'htm.java', 'Implementation-Version': '0.6.8'
}
}

Expand Down Expand Up @@ -126,6 +126,7 @@ uploadArchives {
javadoc.failOnError = false

if(!project.hasProperty('ossrhUsername')) {
println "returning from has Property false"
return
}

Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

<groupId>org.numenta</groupId>
<artifactId>htm.java</artifactId>
<version>0.6.7-SNAPSHOT</version>
<version>0.6.8</version>
<name>htm.java</name>
<description>The Java version of Numenta's HTM technology</description>

Expand Down
8 changes: 7 additions & 1 deletion src/main/java/org/numenta/nupic/network/Network.java
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,15 @@ public static PALayer<?> createPALayer(String name, Parameters p) {
@SuppressWarnings("unchecked")
@Override
public Network preSerialize() {
if(shouldDoHalt) {
if(shouldDoHalt && isThreadRunning) {
halt();
}else{ // Make sure "close()" has been called on the Network
if(regions.size() == 1) {
this.tail = regions.get(0);
}
tail.close();
}

regions.stream().forEach(r -> r.preSerialize());
return this;
}
Expand Down
5 changes: 3 additions & 2 deletions src/main/java/org/numenta/nupic/serialize/HTMObjectInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import java.io.InputStream;

import org.numenta.nupic.Persistable;
import org.nustaq.serialization.FSTConfiguration;
import org.nustaq.serialization.FSTObjectInput;

public class HTMObjectInput extends FSTObjectInput {
public HTMObjectInput(InputStream in) throws IOException {
super(in);
public HTMObjectInput(InputStream in, FSTConfiguration config) throws IOException {
super(in, config);
}

@SuppressWarnings("rawtypes")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import java.io.OutputStream;

import org.numenta.nupic.Persistable;
import org.nustaq.serialization.FSTConfiguration;
import org.nustaq.serialization.FSTObjectOutput;

public class HTMObjectOutput extends FSTObjectOutput {
public HTMObjectOutput(OutputStream out) {
super(out);
public HTMObjectOutput(OutputStream out, FSTConfiguration config) {
super(out, config);
}

@SuppressWarnings("rawtypes")
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/numenta/nupic/serialize/SerializerCore.java
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundE
* @throws IOException
*/
public HTMObjectInput getObjectInput(InputStream is) throws IOException {
return new HTMObjectInput(is);
return new HTMObjectInput(is, fastSerialConfig);
}

/**
Expand All @@ -113,7 +113,7 @@ public HTMObjectInput getObjectInput(InputStream is) throws IOException {
* @return the HTMObjectOutput
*/
public <T extends Persistable> HTMObjectOutput getObjectOutput(OutputStream os) {
return new HTMObjectOutput(os);
return new HTMObjectOutput(os, fastSerialConfig);
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package org.numenta.nupic.serialize;

import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;

import org.junit.Test;
import org.numenta.nupic.Parameters;
import org.numenta.nupic.Parameters.KEY;
import org.numenta.nupic.algorithms.Anomaly;
import org.numenta.nupic.algorithms.SpatialPooler;
import org.numenta.nupic.algorithms.TemporalMemory;
import org.numenta.nupic.network.Network;
import org.numenta.nupic.network.NetworkTestHarness;
import org.numenta.nupic.network.Persistence;
import org.numenta.nupic.network.PublisherSupplier;
import org.numenta.nupic.network.sensor.ObservableSensor;
import org.numenta.nupic.network.sensor.Sensor;
import org.numenta.nupic.network.sensor.SensorParams;
import org.numenta.nupic.network.sensor.SensorParams.Keys;
import org.numenta.nupic.util.FastRandom;


public class HTMObjectInputOutputTest {

@Test
public void testRoundTrip() {
Network network = getLoadedHotGymNetwork();
SerializerCore serializer = Persistence.get().serializer();
ByteArrayOutputStream baos = new ByteArrayOutputStream();
HTMObjectOutput writer = serializer.getObjectOutput(baos);
try {
writer.writeObject(network, Network.class);
writer.flush();
writer.close();
}catch(Exception e) {
fail();
}

byte[] bytes = baos.toByteArray();

ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
try {
HTMObjectInput reader = serializer.getObjectInput(bais);
Network serializedNetwork = (Network)reader.readObject(Network.class);
assertNotNull(serializedNetwork);
assertTrue(serializedNetwork.equals(network));
}catch(Exception e) {
e.printStackTrace();
fail();
}
}

private Network getLoadedHotGymNetwork() {
Parameters p = NetworkTestHarness.getParameters().copy();
p = p.union(NetworkTestHarness.getHotGymTestEncoderParams());
p.setParameterByKey(KEY.RANDOM, new FastRandom(42));

Sensor<ObservableSensor<String[]>> sensor = Sensor.create(
ObservableSensor::create, SensorParams.create(Keys::obs, new Object[] {"name",
PublisherSupplier.builder()
.addHeader("timestamp, consumption")
.addHeader("datetime, float")
.addHeader("B").build() }));

Network network = Network.create("test network", p).add(Network.createRegion("r1")
.add(Network.createLayer("1", p)
.alterParameter(KEY.AUTO_CLASSIFY, true)
.add(Anomaly.create())
.add(new TemporalMemory())
.add(new SpatialPooler())
.add(sensor)));

return network;
}
}

0 comments on commit 4684f8b

Please sign in to comment.