Skip to content

Commit 662a97d

Browse files
authored
Convert MemoryBufferProtocolWrapper to generic (#1906)
* Convert `MemoryBufferProtocolWrapper` to generic * Autodetect type code
1 parent 3f31f45 commit 662a97d

2 files changed

Lines changed: 62 additions & 50 deletions

File tree

src/core/IronPython/Runtime/Binding/ConversionBinder.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,7 @@ private DynamicMetaObject ConvertFromMemoryToBufferProtocol(DynamicMetaObject se
734734
return new DynamicMetaObject(
735735
AstUtils.Convert(
736736
Ast.New(
737-
typeof(MemoryBufferProtocolWrapper).GetConstructor(new Type[] { fromType }),
737+
typeof(MemoryBufferProtocolWrapper<byte>).GetConstructor([fromType]),
738738
AstUtils.Convert(self.Expression, fromType)
739739
),
740740
typeof(IBufferProtocol)

src/core/IronPython/Runtime/ConversionWrappers.cs

Lines changed: 61 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using System.Collections.Generic;
1111
using System.Diagnostics.CodeAnalysis;
1212
using System.Numerics;
13+
using System.Runtime.InteropServices;
1314

1415
namespace IronPython.Runtime {
1516

@@ -376,79 +377,90 @@ IEnumerator IEnumerable.GetEnumerator() {
376377
#endregion
377378
}
378379

379-
public sealed class MemoryBufferWrapper : IPythonBuffer {
380-
private readonly ReadOnlyMemory<byte> _rom;
381-
private readonly Memory<byte>? _memory;
382-
private readonly BufferFlags _flags;
380+
public sealed class MemoryBufferProtocolWrapper<T> : IBufferProtocol where T : unmanaged {
381+
private readonly ReadOnlyMemory<T> _rom;
382+
private readonly Memory<T>? _memory;
383+
private readonly char _format;
383384

384-
public MemoryBufferWrapper(ReadOnlyMemory<byte> memory, BufferFlags flags) {
385+
public MemoryBufferProtocolWrapper(ReadOnlyMemory<T> memory) {
385386
_rom = memory;
386387
_memory = null;
387-
_flags = flags;
388+
_format = GetFormatChar();
388389
}
389390

390-
public MemoryBufferWrapper(Memory<byte> memory, BufferFlags flags) {
391+
public MemoryBufferProtocolWrapper(Memory<T> memory) {
391392
_rom = memory;
392393
_memory = memory;
393-
_flags = flags;
394+
_format = GetFormatChar();
394395
}
395396

396-
public void Dispose() { }
397-
398-
public object Object => _memory ?? _rom;
399-
400-
public bool IsReadOnly => !_memory.HasValue;
397+
public IPythonBuffer? GetBuffer(BufferFlags flags, bool throwOnError) {
398+
if (flags.HasFlag(BufferFlags.Writable) && !_memory.HasValue) {
399+
if (throwOnError) {
400+
throw Operations.PythonOps.BufferError("ReadOnlyMemory is not writable.");
401+
}
402+
return null;
403+
}
401404

402-
public ReadOnlySpan<byte> AsReadOnlySpan() => _rom.Span;
405+
return new MemoryBufferWrapper(this, flags);
406+
}
407+
408+
private static char GetFormatChar()
409+
=> Type.GetTypeCode(typeof(T)) switch {
410+
TypeCode.SByte => 'b',
411+
TypeCode.Byte => 'B',
412+
TypeCode.Char => 'u',
413+
TypeCode.Int16 => 'h',
414+
TypeCode.UInt16 => 'H',
415+
TypeCode.Int32 => 'i',
416+
TypeCode.UInt32 => 'I',
417+
TypeCode.Int64 => 'q',
418+
TypeCode.UInt64 => 'Q',
419+
TypeCode.Single => 'f',
420+
TypeCode.Double => 'd',
421+
_ => throw new ArgumentException("Unsupported type"),
422+
};
423+
424+
425+
private sealed unsafe class MemoryBufferWrapper : IPythonBuffer {
426+
private readonly MemoryBufferProtocolWrapper<T> _wrapper;
427+
private readonly BufferFlags _flags;
428+
429+
public MemoryBufferWrapper(MemoryBufferProtocolWrapper<T> wrapper, BufferFlags flags) {
430+
_wrapper = wrapper;
431+
_flags = flags;
432+
}
403433

404-
public Span<byte> AsSpan() => _memory.HasValue ? _memory.Value.Span : throw new InvalidOperationException("ReadOnlyMemory is not writable");
434+
public void Dispose() { }
405435

406-
public MemoryHandle Pin() => _rom.Pin();
436+
public object Object => _wrapper._memory ?? _wrapper._rom;
407437

408-
public int Offset => 0;
438+
public bool IsReadOnly => !_wrapper._memory.HasValue;
409439

410-
public string? Format => _flags.HasFlag(BufferFlags.Format) ? "B" : null;
440+
public ReadOnlySpan<byte> AsReadOnlySpan() => MemoryMarshal.Cast<T, byte>(_wrapper._rom.Span);
411441

412-
public int ItemCount => _rom.Length;
442+
public Span<byte> AsSpan()
443+
=> _wrapper._memory.HasValue
444+
? MemoryMarshal.Cast<T, byte>(_wrapper._memory.Value.Span)
445+
: throw new InvalidOperationException("ReadOnlyMemory is not writable");
413446

414-
public int ItemSize => 1;
447+
public MemoryHandle Pin() => _wrapper._rom.Pin();
415448

416-
public int NumOfDims => 1;
449+
public int Offset => 0;
417450

418-
public IReadOnlyList<int>? Shape => null;
451+
public string? Format => _flags.HasFlag(BufferFlags.Format) ? _wrapper._format.ToString() : null;
419452

420-
public IReadOnlyList<int>? Strides => null;
453+
public int ItemCount => _wrapper._rom.Length;
421454

422-
public IReadOnlyList<int>? SubOffsets => null;
423-
}
455+
public int ItemSize => sizeof(T);
424456

425-
public class MemoryBufferProtocolWrapper : IBufferProtocol {
426-
private readonly ReadOnlyMemory<byte> _rom;
427-
private readonly Memory<byte>? _memory;
457+
public int NumOfDims => 1;
428458

429-
public MemoryBufferProtocolWrapper(ReadOnlyMemory<byte> memory) {
430-
_rom = memory;
431-
_memory = null;
432-
}
459+
public IReadOnlyList<int>? Shape => null;
433460

434-
public MemoryBufferProtocolWrapper(Memory<byte> memory) {
435-
_rom = memory;
436-
_memory = memory;
437-
}
438-
439-
public IPythonBuffer? GetBuffer(BufferFlags flags, bool throwOnError) {
440-
if (_memory.HasValue) {
441-
return new MemoryBufferWrapper(_memory.Value, flags);
442-
}
443-
444-
if (flags.HasFlag(BufferFlags.Writable)) {
445-
if (throwOnError) {
446-
throw Operations.PythonOps.BufferError("ReadOnlyMemory is not writable.");
447-
}
448-
return null;
449-
}
461+
public IReadOnlyList<int>? Strides => null;
450462

451-
return new MemoryBufferWrapper(_rom, flags);
463+
public IReadOnlyList<int>? SubOffsets => null;
452464
}
453465
}
454466
}

0 commit comments

Comments
 (0)