|
10 | 10 | using System.Collections.Generic; |
11 | 11 | using System.Diagnostics.CodeAnalysis; |
12 | 12 | using System.Numerics; |
| 13 | +using System.Runtime.InteropServices; |
13 | 14 |
|
14 | 15 | namespace IronPython.Runtime { |
15 | 16 |
|
@@ -376,79 +377,90 @@ IEnumerator IEnumerable.GetEnumerator() { |
376 | 377 | #endregion |
377 | 378 | } |
378 | 379 |
|
379 | | - public sealed class MemoryBufferWrapper : IPythonBuffer { |
380 | | - private readonly ReadOnlyMemory<byte> _rom; |
381 | | - private readonly Memory<byte>? _memory; |
382 | | - private readonly BufferFlags _flags; |
| 380 | + public sealed class MemoryBufferProtocolWrapper<T> : IBufferProtocol where T : unmanaged { |
| 381 | + private readonly ReadOnlyMemory<T> _rom; |
| 382 | + private readonly Memory<T>? _memory; |
| 383 | + private readonly char _format; |
383 | 384 |
|
384 | | - public MemoryBufferWrapper(ReadOnlyMemory<byte> memory, BufferFlags flags) { |
| 385 | + public MemoryBufferProtocolWrapper(ReadOnlyMemory<T> memory) { |
385 | 386 | _rom = memory; |
386 | 387 | _memory = null; |
387 | | - _flags = flags; |
| 388 | + _format = GetFormatChar(); |
388 | 389 | } |
389 | 390 |
|
390 | | - public MemoryBufferWrapper(Memory<byte> memory, BufferFlags flags) { |
| 391 | + public MemoryBufferProtocolWrapper(Memory<T> memory) { |
391 | 392 | _rom = memory; |
392 | 393 | _memory = memory; |
393 | | - _flags = flags; |
| 394 | + _format = GetFormatChar(); |
394 | 395 | } |
395 | 396 |
|
396 | | - public void Dispose() { } |
397 | | - |
398 | | - public object Object => _memory ?? _rom; |
399 | | - |
400 | | - public bool IsReadOnly => !_memory.HasValue; |
| 397 | + public IPythonBuffer? GetBuffer(BufferFlags flags, bool throwOnError) { |
| 398 | + if (flags.HasFlag(BufferFlags.Writable) && !_memory.HasValue) { |
| 399 | + if (throwOnError) { |
| 400 | + throw Operations.PythonOps.BufferError("ReadOnlyMemory is not writable."); |
| 401 | + } |
| 402 | + return null; |
| 403 | + } |
401 | 404 |
|
402 | | - public ReadOnlySpan<byte> AsReadOnlySpan() => _rom.Span; |
| 405 | + return new MemoryBufferWrapper(this, flags); |
| 406 | + } |
| 407 | + |
| 408 | + private static char GetFormatChar() |
| 409 | + => Type.GetTypeCode(typeof(T)) switch { |
| 410 | + TypeCode.SByte => 'b', |
| 411 | + TypeCode.Byte => 'B', |
| 412 | + TypeCode.Char => 'u', |
| 413 | + TypeCode.Int16 => 'h', |
| 414 | + TypeCode.UInt16 => 'H', |
| 415 | + TypeCode.Int32 => 'i', |
| 416 | + TypeCode.UInt32 => 'I', |
| 417 | + TypeCode.Int64 => 'q', |
| 418 | + TypeCode.UInt64 => 'Q', |
| 419 | + TypeCode.Single => 'f', |
| 420 | + TypeCode.Double => 'd', |
| 421 | + _ => throw new ArgumentException("Unsupported type"), |
| 422 | + }; |
| 423 | + |
| 424 | + |
| 425 | + private sealed unsafe class MemoryBufferWrapper : IPythonBuffer { |
| 426 | + private readonly MemoryBufferProtocolWrapper<T> _wrapper; |
| 427 | + private readonly BufferFlags _flags; |
| 428 | + |
| 429 | + public MemoryBufferWrapper(MemoryBufferProtocolWrapper<T> wrapper, BufferFlags flags) { |
| 430 | + _wrapper = wrapper; |
| 431 | + _flags = flags; |
| 432 | + } |
403 | 433 |
|
404 | | - public Span<byte> AsSpan() => _memory.HasValue ? _memory.Value.Span : throw new InvalidOperationException("ReadOnlyMemory is not writable"); |
| 434 | + public void Dispose() { } |
405 | 435 |
|
406 | | - public MemoryHandle Pin() => _rom.Pin(); |
| 436 | + public object Object => _wrapper._memory ?? _wrapper._rom; |
407 | 437 |
|
408 | | - public int Offset => 0; |
| 438 | + public bool IsReadOnly => !_wrapper._memory.HasValue; |
409 | 439 |
|
410 | | - public string? Format => _flags.HasFlag(BufferFlags.Format) ? "B" : null; |
| 440 | + public ReadOnlySpan<byte> AsReadOnlySpan() => MemoryMarshal.Cast<T, byte>(_wrapper._rom.Span); |
411 | 441 |
|
412 | | - public int ItemCount => _rom.Length; |
| 442 | + public Span<byte> AsSpan() |
| 443 | + => _wrapper._memory.HasValue |
| 444 | + ? MemoryMarshal.Cast<T, byte>(_wrapper._memory.Value.Span) |
| 445 | + : throw new InvalidOperationException("ReadOnlyMemory is not writable"); |
413 | 446 |
|
414 | | - public int ItemSize => 1; |
| 447 | + public MemoryHandle Pin() => _wrapper._rom.Pin(); |
415 | 448 |
|
416 | | - public int NumOfDims => 1; |
| 449 | + public int Offset => 0; |
417 | 450 |
|
418 | | - public IReadOnlyList<int>? Shape => null; |
| 451 | + public string? Format => _flags.HasFlag(BufferFlags.Format) ? _wrapper._format.ToString() : null; |
419 | 452 |
|
420 | | - public IReadOnlyList<int>? Strides => null; |
| 453 | + public int ItemCount => _wrapper._rom.Length; |
421 | 454 |
|
422 | | - public IReadOnlyList<int>? SubOffsets => null; |
423 | | - } |
| 455 | + public int ItemSize => sizeof(T); |
424 | 456 |
|
425 | | - public class MemoryBufferProtocolWrapper : IBufferProtocol { |
426 | | - private readonly ReadOnlyMemory<byte> _rom; |
427 | | - private readonly Memory<byte>? _memory; |
| 457 | + public int NumOfDims => 1; |
428 | 458 |
|
429 | | - public MemoryBufferProtocolWrapper(ReadOnlyMemory<byte> memory) { |
430 | | - _rom = memory; |
431 | | - _memory = null; |
432 | | - } |
| 459 | + public IReadOnlyList<int>? Shape => null; |
433 | 460 |
|
434 | | - public MemoryBufferProtocolWrapper(Memory<byte> memory) { |
435 | | - _rom = memory; |
436 | | - _memory = memory; |
437 | | - } |
438 | | - |
439 | | - public IPythonBuffer? GetBuffer(BufferFlags flags, bool throwOnError) { |
440 | | - if (_memory.HasValue) { |
441 | | - return new MemoryBufferWrapper(_memory.Value, flags); |
442 | | - } |
443 | | - |
444 | | - if (flags.HasFlag(BufferFlags.Writable)) { |
445 | | - if (throwOnError) { |
446 | | - throw Operations.PythonOps.BufferError("ReadOnlyMemory is not writable."); |
447 | | - } |
448 | | - return null; |
449 | | - } |
| 461 | + public IReadOnlyList<int>? Strides => null; |
450 | 462 |
|
451 | | - return new MemoryBufferWrapper(_rom, flags); |
| 463 | + public IReadOnlyList<int>? SubOffsets => null; |
452 | 464 | } |
453 | 465 | } |
454 | 466 | } |
0 commit comments