Skip to content

Commit

Permalink
WIP #982 Matlab file read/write for double and boolean matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
axkr committed May 3, 2024
1 parent 1d13711 commit af81dec
Show file tree
Hide file tree
Showing 10 changed files with 391 additions and 35 deletions.
4 changes: 4 additions & 0 deletions symja_android_library/matheclipse-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@
<groupId>de.labathome</groupId>
<artifactId>AdaptiveQuadrature</artifactId>
</dependency>
<dependency>
<groupId>us.hebi.matlab.mat</groupId>
<artifactId>mfl-core</artifactId>
</dependency>
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package org.matheclipse.core.convert.matlab;

import static us.hebi.matlab.mat.format.Mat5WriteUtil.computeArrayHeaderSize;
import static us.hebi.matlab.mat.format.Mat5WriteUtil.writeArrayHeader;
import static us.hebi.matlab.mat.format.Mat5WriteUtil.writeMatrixTag;
import java.io.IOException;
import org.hipparchus.linear.AnyMatrix;
import us.hebi.matlab.mat.format.Mat5;
import us.hebi.matlab.mat.format.Mat5Serializable;
import us.hebi.matlab.mat.types.AbstractArray;
import us.hebi.matlab.mat.types.Sink;

/**
* Serializes a Symja Matrix into a MAT 5 file that can be read by MATLAB
*
* Note that implementing 'Mat5Attributes' lets us get around the overhead of implementing the
* Matrix / Sparse interfaces, or alternatively writing the header manually.
*
*/
abstract class AbstractMatrixWrapper<M extends AnyMatrix> extends AbstractArray
implements Mat5Serializable, Mat5Serializable.Mat5Attributes {

protected final M matrix;

protected AbstractMatrixWrapper(M matrix) {
super(Mat5.dims(matrix.getRowDimension(), matrix.getColumnDimension()));
this.matrix = matrix;
}

@Override
public void close() throws IOException {
}

@Override
public int[] getDimensions() {
dims[0] = matrix.getRowDimension();
dims[1] = matrix.getColumnDimension();
return dims;
}

protected abstract int getMat5DataSize();

@Override
public int getMat5Size(String name) {
return Mat5.MATRIX_TAG_SIZE
+ computeArrayHeaderSize(name, this)
+ getMat5DataSize();
}

@Override
public int getNzMax() {
return 0;
}

@Override
public boolean isComplex() {
return false;
}

@Override
public boolean isLogical() {
return false;
}

@Override
protected boolean subEqualsGuaranteedSameClass(Object otherGuaranteedSameClass) {
AnyMatrixWrapper other = (AnyMatrixWrapper) otherGuaranteedSameClass;
return other.matrix.equals(matrix);
}

@Override
protected int subHashCode() {
return matrix.hashCode();
}

@Override
public void writeMat5(String name, boolean isGlobal, Sink sink) throws IOException {
writeMatrixTag(name, this, sink);
writeArrayHeader(name, isGlobal, this, sink);
writeMat5Data(sink);
}

/**
* Writes data part in column-major order
*
* @param sink
* @throws IOException
*/
protected abstract void writeMat5Data(Sink sink) throws IOException;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package org.matheclipse.core.convert.matlab;

import java.io.IOException;
import java.io.NotSerializableException;
import org.hipparchus.linear.AnyMatrix;
import org.hipparchus.linear.RealMatrix;
import org.matheclipse.core.eval.exception.SymjaMathException;
import org.matheclipse.core.expression.ASTRealMatrix;
import org.matheclipse.core.interfaces.IAST;
import us.hebi.matlab.mat.format.Mat5Type;
import us.hebi.matlab.mat.types.MatlabType;
import us.hebi.matlab.mat.types.Sink;

class AnyMatrixWrapper extends AbstractMatrixWrapper<AnyMatrix> {

AnyMatrixWrapper(AnyMatrix matrix) {
super(matrix);
}

@Override
protected int getMat5DataSize() {
return Mat5Type.Double
.computeSerializedSize(matrix.getRowDimension() * matrix.getColumnDimension());
}

@Override
public MatlabType getType() {
return MatlabType.Double;
}

@Override
protected void writeMat5Data(Sink sink) throws IOException {
// Real data in column major format
if (matrix instanceof ASTRealMatrix) {
ASTRealMatrix astMatrix = (ASTRealMatrix) matrix;
RealMatrix realMatrix = astMatrix.getRealMatrix();
int rows = realMatrix.getRowDimension();
int columns = realMatrix.getColumnDimension();
int getNumElements = rows * columns;

Mat5Type.Double.writeTag(getNumElements, sink);
for (int col = 0; col < rows; col++) {
for (int row = 0; row < columns; row++) {
sink.writeDouble(realMatrix.getEntry(row, col));
}
}
Mat5Type.Double.writePadding(getNumElements, sink);
} else if (matrix instanceof IAST) {
try {
IAST astMatrix = (IAST) matrix;
int rows = astMatrix.getRowDimension();
int columns = astMatrix.getColumnDimension();
int getNumElements = rows * columns;

Mat5Type.Double.writeTag(getNumElements, sink);
for (int col = 0; col < rows; col++) {
for (int row = 0; row < columns; row++) {
sink.writeDouble(astMatrix.getPart(row + 1, col + 1).evalf());
}
}
Mat5Type.Double.writePadding(getNumElements, sink);
} catch (SymjaMathException sme) {
throw new NotSerializableException();
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package org.matheclipse.core.convert.matlab;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import org.hipparchus.linear.AnyMatrix;
import org.hipparchus.linear.Array2DRowRealMatrix;
import org.hipparchus.linear.RealMatrix;
import org.matheclipse.core.expression.ASTRealMatrix;
import org.matheclipse.core.expression.F;
import org.matheclipse.core.interfaces.IAST;
import org.matheclipse.core.interfaces.IASTAppendable;
import org.matheclipse.core.interfaces.IExpr;
import us.hebi.matlab.mat.format.Mat5;
import us.hebi.matlab.mat.format.Mat5File;
import us.hebi.matlab.mat.types.AbstractMatrixBase;
import us.hebi.matlab.mat.types.Array;
import us.hebi.matlab.mat.types.MatFile;
import us.hebi.matlab.mat.types.MatlabType;
import us.hebi.matlab.mat.types.Matrix;
import us.hebi.matlab.mat.types.Source;
import us.hebi.matlab.mat.types.Sources;

/**
* Matlab file format conversion to Symja
*/
public class Mat5Symja {

private static AnyMatrix convertToAnyMatrix(AbstractMatrixBase input,
Class<? extends AnyMatrix> clazz) {
final int rows = input.getNumRows();
final int cols = input.getNumCols();
if (clazz.isAssignableFrom(ASTRealMatrix.class)) {
RealMatrix realMatrix = convertToArray2DRowRealMatrix(input, rows, cols);
return new ASTRealMatrix(realMatrix, false);
} else if (clazz.isAssignableFrom(IAST.class)) {
IASTAppendable astMatrix = F.ListAlloc(rows);
for (int i = 0; i < rows; i++) {
astMatrix.append(F.ListAlloc(cols));
}
for (int col = 0; col < cols; col++) {
for (int row = 0; row < rows; row++) {
astMatrix.setPart(input.getDouble(row, col), row, col);
}
}
} else if (clazz.isAssignableFrom(RealMatrix.class)) {
return convertToArray2DRowRealMatrix(input, rows, cols);
}
return F.NIL;
}

private static RealMatrix convertToArray2DRowRealMatrix(Matrix input, final int rows,
final int cols) {
RealMatrix realMatrix = new Array2DRowRealMatrix(rows, cols);
for (int col = 0; col < cols; col++) {
for (int row = 0; row < rows; row++) {
realMatrix.setEntry(row, col, input.getDouble(row, col));
}
}
return realMatrix;
}

public static IAST getTensor(AbstractMatrixBase baseMatrix) {
int[] dimensions = baseMatrix.getDimensions();
if (dimensions.length == 2) {
if (baseMatrix.getType() == MatlabType.Double) {
return (ASTRealMatrix) Mat5Symja.convertToAnyMatrix(baseMatrix, ASTRealMatrix.class);
}
if (baseMatrix.getType() == MatlabType.Sparse) {
return (IAST) Mat5Symja.convertToAnyMatrix(baseMatrix, IAST.class);
}
}
if (baseMatrix.getType() == MatlabType.Sparse) {
return F.NIL;
// int[] indices = new int[dimensions.length];
// final int size = dimensions[0];
// ISparseArray sparse =
// F.sparseArray(F.List(F.Rule(F.List(1, 2, 3), F.b), F.Rule(F.List(1, 4, 5), F.a)));
// IASTAppendable result = F.ListAlloc();
// for (int i = 0; i < size; i++) {
// indices[0] = i;
// getSparseRecursive(baseMatrix, dimensions, indices, 1, result);
// }
// return result;
}
int[] indices = new int[dimensions.length];
final int size = dimensions[0];
IASTAppendable result = F.ListAlloc(size);
for (int i = 0; i < size; i++) {
indices[0] = i;
getTensorRecursive(baseMatrix, dimensions, indices, 1, result);
}
return result;
}

private static void getSparseRecursive(AbstractMatrixBase baseMatrix, int[] dimensions,
int[] indices, int indexCounter, IASTAppendable result) {
int newCounter = indexCounter + 1;
if (indexCounter == dimensions.length) {
MatlabType type = baseMatrix.getType();
switch (type) {
case UInt8:
result.append(baseMatrix.getBoolean(indices));
return;
case Double:
result.append(baseMatrix.getDouble(indices));
return;
case Single:
result.append(baseMatrix.getFloat(indices));
return;
case Sparse:
double d = baseMatrix.getDouble(indices);
result.append(d);
return;
}
return;
}
final int size = dimensions[indexCounter];
IASTAppendable subRow = F.ListAlloc();
for (int i = 0; i < size; i++) {
indices[indexCounter] = i;
getTensorRecursive(baseMatrix, dimensions, indices, newCounter, subRow);
}
result.append(subRow);
}

private static void getTensorRecursive(AbstractMatrixBase baseMatrix, int[] dimensions,
int[] indices, int indexCounter, IASTAppendable result) {
int newCounter = indexCounter + 1;
if (indexCounter == dimensions.length) {
MatlabType type = baseMatrix.getType();
switch (type) {
case UInt8:
result.append(baseMatrix.getBoolean(indices));
return;
case Double:
result.append(baseMatrix.getDouble(indices));
return;
case Single:
result.append(baseMatrix.getFloat(indices));
return;
}
return;
}
final int size = dimensions[indexCounter];
IASTAppendable subRow = F.ListAlloc(size);
for (int i = 0; i < size; i++) {
indices[indexCounter] = i;
getTensorRecursive(baseMatrix, dimensions, indices, newCounter, subRow);
}
result.append(subRow);
}

public static IExpr importMAT(InputStream inputStream, String inputName)
throws IOException, AssertionError {
ByteBuffer buffer = ByteBuffer.allocate(inputStream.available());
int bytes = inputStream.read(buffer.array());
if (bytes != buffer.array().length) {
throw new AssertionError("Could not read full contents of " + inputName);
}
try (Source source = Sources.wrap(buffer)) {
Mat5File mat = Mat5.newReader(source)//
.setReducedHeader(false)//
.readMat();
System.out.println(mat.toString());
for (MatFile.Entry entry : mat.getEntries()) {
// String name = entry.getName();
Array value = entry.getValue();
if (value instanceof AbstractMatrixBase) {
return getTensor((AbstractMatrixBase) value);
}
}
}
return F.NIL;
}
}
4 changes: 0 additions & 4 deletions symja_android_library/matheclipse-io/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,6 @@
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-memory-netty</artifactId>
</dependency>
<dependency>
<groupId>us.hebi.matlab.mat</groupId>
<artifactId>mfl-core</artifactId>
</dependency>
<!-- logging dependencies -->

<dependency>
Expand Down
Loading

0 comments on commit af81dec

Please sign in to comment.