Skip to content

Commit

Permalink
improved tensors API
Browse files Browse the repository at this point in the history
  • Loading branch information
vpenades committed Apr 30, 2024
1 parent 28f3a7c commit c5d7fef
Show file tree
Hide file tree
Showing 9 changed files with 848 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
<PackageId>InteropTypes.Tensors.ONNX</PackageId>

<!-- packed as source code package -->
<IsPackable>false</IsPackable>

<IsPackable>false</IsPackable>
<IsPackableAsSources>true</IsPackableAsSources>
</PropertyGroup>

Expand Down
103 changes: 103 additions & 0 deletions src/InteropTypes.Tensors.OnnxRuntime/NamedValueExtensions.pp.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright (c) InteropTypes 2024 Vicente Penades

#nullable disable

using System;
using System.Linq;
using System.Numerics;
using InteropTypes.Tensors;

using ONNX = Microsoft.ML.OnnxRuntime;
using ONNXTENSORS = Microsoft.ML.OnnxRuntime.Tensors;

using NAMEDVALUE = Microsoft.ML.OnnxRuntime.NamedOnnxValue;


#if INTEROPTYPES_USEINTEROPNAMESPACE
namespace InteropTypes.Tensors
#elif INTEROPTYPES_TENSORS_USEONNXRUNTIMENAMESPACE
namespace Microsoft.ML.OnnxRuntime
#else
namespace $rootnamespace$
#endif
{
static partial class InteropTensorsForOnnxRuntime
{
public static Type GetElementType(this ONNXTENSORS.TensorElementType etype)
{
switch (etype)
{
case ONNXTENSORS.TensorElementType.Bool: return typeof(Boolean);

case ONNXTENSORS.TensorElementType.Int8: return typeof(SByte);
case ONNXTENSORS.TensorElementType.UInt8: return typeof(Byte);

case ONNXTENSORS.TensorElementType.Int16: return typeof(Int16);
case ONNXTENSORS.TensorElementType.UInt16: return typeof(UInt16);

case ONNXTENSORS.TensorElementType.Int32: return typeof(Int32);
case ONNXTENSORS.TensorElementType.UInt32: return typeof(UInt32);

case ONNXTENSORS.TensorElementType.Int64: return typeof(Int64);
case ONNXTENSORS.TensorElementType.UInt64: return typeof(UInt64);

case ONNXTENSORS.TensorElementType.Float16: return typeof(Half);
case ONNXTENSORS.TensorElementType.Float: return typeof(Single);
case ONNXTENSORS.TensorElementType.Double: return typeof(Double);

case ONNXTENSORS.TensorElementType.String: return typeof(String);

case ONNXTENSORS.TensorElementType.Complex64: return typeof(Complex); // needs checking

default: throw new NotImplementedException(etype.ToString());
}
}

public static ONNXTENSORS.DenseTensor<T> AsDenseTensor<T>(this NAMEDVALUE nvalue)
{
if (nvalue.Value is ONNXTENSORS.DenseTensor<T> dtensor) return dtensor;

return nvalue.AsTensor<T>().ToDenseTensor();
}

public static NAMEDVALUE CreateNamedTensor(this ONNX.NodeMetadata metadata, string name, ReadOnlySpan<int> dimensions)
{
if (dimensions.IsEmpty)
{
dimensions = metadata.Dimensions;

if (metadata.Dimensions.Any(dim => dim <= 0))
{
var array = dimensions.ToArray();
for (int i = 0; i < dimensions.Length; ++i)
{
if (array[i] <= 0) array[i] = 1;
}
dimensions = array;
}
}

if (metadata.ElementType == typeof(Boolean)) return NAMEDVALUE.CreateFromTensor(name, new ONNXTENSORS.DenseTensor<Boolean>(dimensions));
if (metadata.ElementType == typeof(Char)) return NAMEDVALUE.CreateFromTensor(name, new ONNXTENSORS.DenseTensor<Char>(dimensions));

if (metadata.ElementType == typeof(SByte)) return NAMEDVALUE.CreateFromTensor(name, new ONNXTENSORS.DenseTensor<SByte>(dimensions));
if (metadata.ElementType == typeof(Byte)) return NAMEDVALUE.CreateFromTensor(name, new ONNXTENSORS.DenseTensor<Byte>(dimensions));

if (metadata.ElementType == typeof(Int16)) return NAMEDVALUE.CreateFromTensor(name, new ONNXTENSORS.DenseTensor<Int16>(dimensions));
if (metadata.ElementType == typeof(UInt16)) return NAMEDVALUE.CreateFromTensor(name, new ONNXTENSORS.DenseTensor<UInt16>(dimensions));

if (metadata.ElementType == typeof(Int32)) return NAMEDVALUE.CreateFromTensor(name, new ONNXTENSORS.DenseTensor<Int32>(dimensions));
if (metadata.ElementType == typeof(UInt32)) return NAMEDVALUE.CreateFromTensor(name, new ONNXTENSORS.DenseTensor<UInt32>(dimensions));

if (metadata.ElementType == typeof(Int64)) return NAMEDVALUE.CreateFromTensor(name, new ONNXTENSORS.DenseTensor<Int64>(dimensions));
if (metadata.ElementType == typeof(UInt64)) return NAMEDVALUE.CreateFromTensor(name, new ONNXTENSORS.DenseTensor<UInt64>(dimensions));

if (metadata.ElementType == typeof(Half)) return NAMEDVALUE.CreateFromTensor(name, new ONNXTENSORS.DenseTensor<Half>(dimensions));
if (metadata.ElementType == typeof(Single)) return NAMEDVALUE.CreateFromTensor(name, new ONNXTENSORS.DenseTensor<Single>(dimensions));
if (metadata.ElementType == typeof(Double)) return NAMEDVALUE.CreateFromTensor(name, new ONNXTENSORS.DenseTensor<Double>(dimensions));

throw new NotImplementedException();
}
}

}
135 changes: 91 additions & 44 deletions src/InteropTypes.Tensors.OnnxRuntime/OrtValueExtensions.pp.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
using System.Runtime.CompilerServices;

using InteropTypes.Tensors;
using Microsoft.ML.OnnxRuntime;

using ORTVALUE = Microsoft.ML.OnnxRuntime.OrtValue;
using ORTVALUEINFO = Microsoft.ML.OnnxRuntime.OrtTensorTypeAndShapeInfo;

#if INTEROPTYPES_USEINTEROPNAMESPACE
namespace InteropTypes.Tensors
Expand All @@ -21,159 +23,204 @@ namespace $rootnamespace$
{
static partial class InteropTensorsForOnnxRuntime
{
public static ReadOnlySpanTensor1<T> AsReadOnlySpanTensor1<T>(this OrtValue value)
public static ReadOnlySpanTensor1<T> AsReadOnlySpanTensor1<T>(this ORTVALUE value)
where T: unmanaged
{
return TryGetReadOnlySpanTensor1<T>(value, out var tensor) ? tensor : throw new ArgumentException("invalid tensor", nameof(OrtValue));
return TryGetReadOnlySpanTensor1<T>(value, out var tensor)
? tensor
: throw _GetTensorInfoException<T>(value, nameof(value));
}

public static bool TryGetReadOnlySpanTensor1<T>(this OrtValue value, out ReadOnlySpanTensor1<T> tensor)
public static bool TryGetReadOnlySpanTensor1<T>(this ORTVALUE value, out ReadOnlySpanTensor1<T> tensor)
where T:unmanaged
{
if (!_TryGetTypeAndShape(value, 1, out var info)) { tensor = default; return false; }
if (!_TryGetTensorInfo<T>(value, out var info)) { tensor = default; return false; }

var data = value.GetTensorDataAsSpan<T>();
tensor = new ReadOnlySpanTensor1<T>(data, (int)info.Shape[0]);
var dims = TensorSize1.FromAny(info.Shape.AsSpan());

tensor = new ReadOnlySpanTensor1<T>(data, dims);
return true;
}

public static SpanTensor1<T> AsSpanTensor1<T>(this OrtValue value)
public static SpanTensor1<T> AsSpanTensor1<T>(this ORTVALUE value)
where T : unmanaged
{
return TryGetSpanTensor1<T>(value, out var tensor) ? tensor : throw new ArgumentException("invalid tensor", nameof(OrtValue));
return TryGetSpanTensor1<T>(value, out var tensor)
? tensor
: throw _GetTensorInfoException<T>(value, nameof(value));
}

public static bool TryGetSpanTensor1<T>(this OrtValue value, out SpanTensor1<T> tensor)
public static bool TryGetSpanTensor1<T>(this ORTVALUE value, out SpanTensor1<T> tensor)
where T : unmanaged
{
if (!_TryGetTypeAndShape(value, 1, out var info)) { tensor = default; return false; }
if (!_TryGetTensorInfo<T>(value, out var info)) { tensor = default; return false; }

var data = value.GetTensorMutableDataAsSpan<T>();
var dims = TensorSize1.FromAny(info.Shape.AsSpan());

tensor = new SpanTensor1<T>(data, (int)info.Shape[0]);
tensor = new SpanTensor1<T>(data, dims);

return true;
}

public static ReadOnlySpanTensor2<T> AsReadOnlySpanTensor2<T>(this OrtValue value)
public static ReadOnlySpanTensor2<T> AsReadOnlySpanTensor2<T>(this ORTVALUE value)
where T : unmanaged
{
return TryGetReadOnlySpanTensor2<T>(value, out var tensor) ? tensor : throw new ArgumentException("invalid tensor", nameof(OrtValue));
return TryGetReadOnlySpanTensor2<T>(value, out var tensor)
? tensor
: throw _GetTensorInfoException<T>(value, nameof(value));
}

public static bool TryGetReadOnlySpanTensor2<T>(this OrtValue value, out ReadOnlySpanTensor2<T> tensor)
public static bool TryGetReadOnlySpanTensor2<T>(this ORTVALUE value, out ReadOnlySpanTensor2<T> tensor)
where T : unmanaged
{
if (!_TryGetTypeAndShape(value, 2, out var info)) { tensor = default; return false; }
if (!_TryGetTensorInfo<T>(value, out var info)) { tensor = default; return false; }

var data = value.GetTensorDataAsSpan<T>();
var dims = TensorSize2.FromAny(info.Shape.AsSpan());

tensor = new ReadOnlySpanTensor2<T>(data, (int)info.Shape[0], (int)info.Shape[1]);
tensor = new ReadOnlySpanTensor2<T>(data, dims);

return true;
}

public static SpanTensor2<T> AsSpanTensor2<T>(this OrtValue value)
public static SpanTensor2<T> AsSpanTensor2<T>(this ORTVALUE value)
where T : unmanaged
{
return TryGetSpanTensor2<T>(value, out var tensor) ? tensor : throw new ArgumentException("invalid tensor", nameof(OrtValue));
return TryGetSpanTensor2<T>(value, out var tensor)
? tensor
: throw _GetTensorInfoException<T>(value, nameof(value));
}

public static bool TryGetSpanTensor2<T>(this OrtValue value, out SpanTensor2<T> tensor)
public static bool TryGetSpanTensor2<T>(this ORTVALUE value, out SpanTensor2<T> tensor)
where T : unmanaged
{
if (!_TryGetTypeAndShape(value, 2, out var info)) { tensor = default; return false; }
if (!_TryGetTensorInfo<T>(value, out var info)) { tensor = default; return false; }

var data = value.GetTensorMutableDataAsSpan<T>();
var dims = TensorSize2.FromAny(info.Shape.AsSpan());

tensor = new SpanTensor2<T>(data, (int)info.Shape[0], (int)info.Shape[1]);
tensor = new SpanTensor2<T>(data, dims);

return true;
}

public static ReadOnlySpanTensor3<T> AsReadOnlySpanTensor3<T>(this OrtValue value)
public static ReadOnlySpanTensor3<T> AsReadOnlySpanTensor3<T>(this ORTVALUE value)
where T : unmanaged
{
return TryGetReadOnlySpanTensor3<T>(value, out var tensor) ? tensor : throw new ArgumentException("invalid tensor", nameof(OrtValue));
return TryGetReadOnlySpanTensor3<T>(value, out var tensor)
? tensor
: throw _GetTensorInfoException<T>(value, nameof(value));
}

public static bool TryGetReadOnlySpanTensor3<T>(this OrtValue value, out ReadOnlySpanTensor3<T> tensor)
public static bool TryGetReadOnlySpanTensor3<T>(this ORTVALUE value, out ReadOnlySpanTensor3<T> tensor)
where T : unmanaged
{
if (!_TryGetTypeAndShape(value, 3, out var info)) { tensor = default; return false; }
if (!_TryGetTensorInfo<T>(value, out var info)) { tensor = default; return false; }

var data = value.GetTensorDataAsSpan<T>();
var dims = TensorSize3.FromAny(info.Shape.AsSpan());

tensor = new ReadOnlySpanTensor3<T>(data, (int)info.Shape[0], (int)info.Shape[1], (int)info.Shape[2]);
tensor = new ReadOnlySpanTensor3<T>(data, dims);

return true;
}

public static SpanTensor3<T> AsSpanTensor3<T>(this OrtValue value)
public static SpanTensor3<T> AsSpanTensor3<T>(this ORTVALUE value)
where T : unmanaged
{
return TryGetSpanTensor3<T>(value, out var tensor) ? tensor : throw new ArgumentException("invalid tensor", nameof(OrtValue));
return TryGetSpanTensor3<T>(value, out var tensor)
? tensor
: throw _GetTensorInfoException<T>(value, nameof(value));
}

public static bool TryGetSpanTensor3<T>(this OrtValue value, out SpanTensor3<T> tensor)
public static bool TryGetSpanTensor3<T>(this ORTVALUE value, out SpanTensor3<T> tensor)
where T : unmanaged
{
if (!_TryGetTypeAndShape(value, 3, out var info)) { tensor = default; return false; }
if (!_TryGetTensorInfo<T>(value, out var info)) { tensor = default; return false; }

var data = value.GetTensorMutableDataAsSpan<T>();
var dims = TensorSize3.FromAny(info.Shape.AsSpan());

tensor = new SpanTensor3<T>(data, (int)info.Shape[0], (int)info.Shape[1], (int)info.Shape[2]);
tensor = new SpanTensor3<T>(data, dims);

return true;
}

public static ReadOnlySpanTensor4<T> AsReadOnlySpanTensor4<T>(this OrtValue value)
public static ReadOnlySpanTensor4<T> AsReadOnlySpanTensor4<T>(this ORTVALUE value)
where T : unmanaged
{
return TryGetReadOnlySpanTensor4<T>(value, out var tensor) ? tensor : throw new ArgumentException("invalid tensor", nameof(OrtValue));
return TryGetReadOnlySpanTensor4<T>(value, out var tensor)
? tensor
: throw _GetTensorInfoException<T>(value, nameof(value));
}

public static bool TryGetReadOnlySpanTensor4<T>(this OrtValue value, out ReadOnlySpanTensor4<T> tensor)
public static bool TryGetReadOnlySpanTensor4<T>(this ORTVALUE value, out ReadOnlySpanTensor4<T> tensor)
where T : unmanaged
{
if (!_TryGetTypeAndShape(value, 4, out var info)) { tensor = default; return false; }
if (!_TryGetTensorInfo<T>(value, out var info)) { tensor = default; return false; }

var data = value.GetTensorDataAsSpan<T>();
var dims = TensorSize4.FromAny(info.Shape.AsSpan());

tensor = new ReadOnlySpanTensor4<T>(data, (int)info.Shape[0], (int)info.Shape[1], (int)info.Shape[2], (int)info.Shape[3]);
tensor = new ReadOnlySpanTensor4<T>(data, dims);

return true;
}

public static SpanTensor4<T> AsSpanTensor4<T>(this OrtValue value)
public static SpanTensor4<T> AsSpanTensor4<T>(this ORTVALUE value)
where T : unmanaged
{
return TryGetSpanTensor4<T>(value, out var tensor) ? tensor : throw new ArgumentException("invalid tensor", nameof(OrtValue));
return TryGetSpanTensor4<T>(value, out var tensor)
? tensor
: throw _GetTensorInfoException<T>(value, nameof(value));
}

public static bool TryGetSpanTensor4<T>(this OrtValue value, out SpanTensor4<T> tensor)
public static bool TryGetSpanTensor4<T>(this ORTVALUE value, out SpanTensor4<T> tensor)
where T : unmanaged
{
if (!_TryGetTypeAndShape(value, 4, out var info)) { tensor = default; return false; }
if (!_TryGetTensorInfo<T>(value, out var info)) { tensor = default; return false; }

var data = value.GetTensorMutableDataAsSpan<T>();
var dims = TensorSize4.FromAny(info.Shape.AsSpan());

tensor = new SpanTensor4<T>(data, (int)info.Shape[0], (int)info.Shape[1], (int)info.Shape[2], (int)info.Shape[3]);
tensor = new SpanTensor4<T>(data, dims);

return true;
}

private static bool _TryGetTypeAndShape(OrtValue value, int dims, out OrtTensorTypeAndShapeInfo info)
private static bool _TryGetTensorInfo<T>(ORTVALUE value, out ORTVALUEINFO info)
{
info = default;
if (value == null) return false;
if (!value.IsTensor) return false;
if (value.IsSparseTensor) return false;
if (value.IsSparseTensor) return false;

info = value.GetTensorTypeAndShape();
if (info.Shape.Length != dims) return false;

var dataType = info.ElementDataType.GetElementType();

if (dataType != typeof(T)) return false;

return true;
}

private static Exception _GetTensorInfoException<T>(ORTVALUE value, string name)
{

if (value == null) return new ArgumentNullException(name);
if (!value.IsTensor) return new ArgumentException("Not a tensor", name);
if (value.IsSparseTensor) return new ArgumentException("Not dense tensor", name);

var info = value.GetTensorTypeAndShape();

var dataType = info.ElementDataType.GetElementType();

if (dataType != typeof(T)) return new ArgumentException($"Type mismatch, expected: {info.ElementDataType} but was {typeof(T).Name}", name);

return new ArgumentException("Error", name);
}
}

}
Loading

0 comments on commit c5d7fef

Please sign in to comment.