Skip to content

Commit c0272b2

Browse files
committed
Changed mask implementation to class
1 parent 6a08fcd commit c0272b2

3 files changed

Lines changed: 92 additions & 54 deletions

File tree

src/GraphBLAS-sharp/CSRMatrix.fs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ type CSRMatrix<'a when 'a : struct and 'a : equality>(csrTuples: CSRFormat<'a>)
2525
let rowCount = csrTuples.RowPointers.Length - 1
2626
let columnCount = csrTuples.ColumnCount
2727

28-
let spMV (vector: Vector<'a>) (mask: Mask1D<'a>) (context: Semiring<'a>) : Vector<'a> =
28+
let spMV (vector: Vector<'a>) (mask: Mask1D) (context: Semiring<'a>) : Vector<'a> =
2929
let csrMatrixRowCount = rowCount
3030
let csrMatrixColumnCount = columnCount
3131
let vectorLength = vector.Length
@@ -69,7 +69,7 @@ type CSRMatrix<'a when 'a : struct and 'a : equality>(csrTuples: CSRFormat<'a>)
6969
currentContext.CommandQueue.Add (resultVector.ToHost currentContext.Provider) |> ignore
7070
currentContext.CommandQueue.Finish () |> ignore
7171

72-
upcast DenseVector(resultVector)
72+
upcast DenseVector(resultVector, context.PlusMonoid)
7373

7474
new() = CSRMatrix(CSRFormat.ZeroCreate())
7575

@@ -80,24 +80,27 @@ type CSRMatrix<'a when 'a : struct and 'a : equality>(csrTuples: CSRFormat<'a>)
8080
override this.RowCount = rowCount
8181
override this.ColumnCount = columnCount
8282

83+
override this.CreateMask (isRegular: bool) =
84+
failwith "Not implemented"
85+
8386
override this.Item
84-
with get (mask: Mask2D<'a>) : Matrix<'a> = failwith "Not Implemented"
85-
and set (mask: Mask2D<'a>) (value: Matrix<'a>) = failwith "Not Implemented"
87+
with get (mask: Mask2D) : Matrix<'a> = failwith "Not Implemented"
88+
and set (mask: Mask2D) (value: Matrix<'a>) = failwith "Not Implemented"
8689
override this.Item
87-
with get (vectorMask: Mask1D<'a>, colIdx: int) : Vector<'a> = failwith "Not Implemented"
88-
and set (vectorMask: Mask1D<'a>, colIdx: int) (value: Vector<'a>) = failwith "Not Implemented"
90+
with get (vectorMask: Mask1D, colIdx: int) : Vector<'a> = failwith "Not Implemented"
91+
and set (vectorMask: Mask1D, colIdx: int) (value: Vector<'a>) = failwith "Not Implemented"
8992
override this.Item
90-
with get (rowIdx: int, vectorMask: Mask1D<'a>) : Vector<'a> = failwith "Not Implemented"
91-
and set (rowIdx: int, vectorMask: Mask1D<'a>) (value: Vector<'a>) = failwith "Not Implemented"
93+
with get (rowIdx: int, vectorMask: Mask1D) : Vector<'a> = failwith "Not Implemented"
94+
and set (rowIdx: int, vectorMask: Mask1D) (value: Vector<'a>) = failwith "Not Implemented"
9295
override this.Item
9396
with get (rowIdx: int, colIdx: int) : Scalar<'a> = failwith "Not Implemented"
9497
and set (rowIdx: int, colIdx: int) (value: Scalar<'a>) = failwith "Not Implemented"
9598
override this.Item
96-
with set (mask: Mask2D<'a>) (value: Scalar<'a>) = failwith "Not Implemented"
99+
with set (mask: Mask2D) (value: Scalar<'a>) = failwith "Not Implemented"
97100
override this.Item
98-
with set (vectorMask: Mask1D<'a>, colIdx: int) (value: Scalar<'a>) = failwith "Not Implemented"
101+
with set (vectorMask: Mask1D, colIdx: int) (value: Scalar<'a>) = failwith "Not Implemented"
99102
override this.Item
100-
with set (rowIdx: int, vectorMask: Mask1D<'a>) (value: Scalar<'a>) = failwith "Not Implemented"
103+
with set (rowIdx: int, vectorMask: Mask1D) (value: Scalar<'a>) = failwith "Not Implemented"
101104

102105
override this.Mxm a b c = failwith "Not Implemented"
103106
override this.Mxv a b c = failwith "Not Implemented"

src/GraphBLAS-sharp/DenseVector.fs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,31 @@
11
namespace GraphBLAS.FSharp
22

3-
type DenseVector<'a when 'a : struct and 'a : equality>(vector: 'a[]) =
3+
type DenseVector<'a when 'a : struct and 'a : equality>(vector: 'a[], monoid: Monoid<'a>) =
44
inherit Vector<'a>()
55

6-
new() = DenseVector(Array.zeroCreate<'a> 0)
7-
new(listOfIndices: int list) = DenseVector(Array.zeroCreate<'a> 0)
6+
new(monoid: Monoid<'a>) = DenseVector(Array.zeroCreate<'a> 0, monoid)
7+
new(listOfIndices: int list, monoid: Monoid<'a>) = DenseVector(Array.zeroCreate<'a> 0, monoid)
8+
9+
member this.Monoid = monoid
10+
member this.Values = vector
811

912
override this.Length = failwith "Not Implemented"
1013
override this.AsArray = failwith "Not Implemented"
1114

15+
override this.CreateMask (isRegular: bool) =
16+
let indices =
17+
[| for i in 0 .. this.Length - 1 do
18+
if this.Values.[i] <> this.Monoid.Zero then yield i |]
19+
Mask1D(false, indices, this.Length, isRegular)
20+
1221
override this.Item
13-
with get (mask: Mask1D<'a>) : Vector<'a> = failwith "Not Implemented"
14-
and set (mask: Mask1D<'a>) (value: Vector<'a>) = failwith "Not Implemented"
22+
with get (mask: Mask1D) : Vector<'a> = failwith "Not Implemented"
23+
and set (mask: Mask1D) (value: Vector<'a>) = failwith "Not Implemented"
1524
override this.Item
16-
with get (rowIdx: int, colIdx: int) : Scalar<'a> = failwith "Not Implemented"
17-
and set (rowIdx: int, colIdx: int) (value: Scalar<'a>) = failwith "Not Implemented"
25+
with get (idx: int) : Scalar<'a> = failwith "Not Implemented"
26+
and set (idx: int) (value: Scalar<'a>) = failwith "Not Implemented"
1827
override this.Item
19-
with set (mask: Mask1D<'a>) (value: Scalar<'a>) = failwith "Not Implemented"
28+
with set (mask: Mask1D) (value: Scalar<'a>) = failwith "Not Implemented"
2029

2130
override this.Vxm a b c = failwith "Not Implemented"
2231
override this.EWiseAdd a b c = failwith "Not Implemented"

src/GraphBLAS-sharp/MatrixAndVector.fs

Lines changed: 61 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,28 @@ type Matrix<'a when 'a : struct and 'a : equality>() =
55
abstract RowCount: int
66
abstract ColumnCount: int
77

8-
abstract Item: Mask2D<'a> -> Matrix<'a> with get, set
9-
abstract Item: Mask1D<'a> * int -> Vector<'a> with get, set
10-
abstract Item: int * Mask1D<'a> -> Vector<'a> with get, set
8+
abstract CreateMask: bool -> Mask2D
9+
10+
abstract Item: Mask2D -> Matrix<'a> with get, set
11+
abstract Item: Mask1D * int -> Vector<'a> with get, set
12+
abstract Item: int * Mask1D -> Vector<'a> with get, set
1113
abstract Item: int * int -> Scalar<'a> with get, set
12-
abstract Item: Mask2D<'a> -> Scalar<'a> with set
13-
abstract Item: Mask1D<'a> * int -> Scalar<'a> with set
14-
abstract Item: int * Mask1D<'a> -> Scalar<'a> with set
15-
16-
abstract Mxm: Matrix<'a> -> Mask2D<'a> -> Semiring<'a> -> Matrix<'a>
17-
abstract Mxv: Vector<'a> -> Mask1D<'a> -> Semiring<'a> -> Vector<'a>
18-
abstract EWiseAdd: Matrix<'a> -> Mask2D<'a> -> Semiring<'a> -> Matrix<'a>
19-
abstract EWiseMult: Matrix<'a> -> Mask2D<'a> -> Semiring<'a> -> Matrix<'a>
20-
abstract Apply: Mask1D<'a> -> UnaryOp<'a, 'b> -> Matrix<'b>
21-
abstract ReduceIn: Mask1D<'a> -> Monoid<'a> -> Vector<'a>
22-
abstract ReduceOut: Mask1D<'a> -> Monoid<'a> -> Vector<'a>
14+
abstract Item: Mask2D -> Scalar<'a> with set
15+
abstract Item: Mask1D * int -> Scalar<'a> with set
16+
abstract Item: int * Mask1D -> Scalar<'a> with set
17+
18+
abstract Mxm: Matrix<'a> -> Mask2D -> Semiring<'a> -> Matrix<'a>
19+
abstract Mxv: Vector<'a> -> Mask1D -> Semiring<'a> -> Vector<'a>
20+
abstract EWiseAdd: Matrix<'a> -> Mask2D -> Semiring<'a> -> Matrix<'a>
21+
abstract EWiseMult: Matrix<'a> -> Mask2D -> Semiring<'a> -> Matrix<'a>
22+
abstract Apply: Mask1D -> UnaryOp<'a, 'b> -> Matrix<'b>
23+
abstract ReduceIn: Mask1D -> Monoid<'a> -> Vector<'a>
24+
abstract ReduceOut: Mask1D -> Monoid<'a> -> Vector<'a>
2325
abstract T: Matrix<'a>
2426

25-
abstract EWiseAddInplace: Matrix<'a> -> Mask2D<'a> -> Semiring<'a> -> unit
26-
abstract EWiseMultInplace: Matrix<'a> -> Mask2D<'a> -> Semiring<'a> -> unit
27-
abstract ApplyInplace: Mask2D<'a> -> UnaryOp<'a, 'b> -> unit
27+
abstract EWiseAddInplace: Matrix<'a> -> Mask2D -> Semiring<'a> -> unit
28+
abstract EWiseMultInplace: Matrix<'a> -> Mask2D -> Semiring<'a> -> unit
29+
abstract ApplyInplace: Mask2D -> UnaryOp<'a, 'b> -> unit
2830

2931
static member inline (+) (x: Matrix<'a>, y: Matrix<'a>) = x.EWiseAdd y
3032
static member inline (*) (x: Matrix<'a>, y: Matrix<'a>) = x.EWiseMult y
@@ -37,32 +39,56 @@ and [<AbstractClass>] Vector<'a when 'a : struct and 'a : equality>() =
3739
abstract Length: int
3840
abstract AsArray: 'a[]
3941

40-
abstract Item: Mask1D<'a> -> Vector<'a> with get, set
41-
abstract Item: int * int -> Scalar<'a> with get, set
42-
abstract Item: Mask1D<'a> -> Scalar<'a> with set
42+
abstract CreateMask: bool -> Mask1D
4343

44-
abstract Vxm: Matrix<'a> -> Mask1D<'a> -> Semiring<'a> -> Vector<'a>
45-
abstract EWiseAdd: Vector<'a> -> Mask1D<'a> -> Semiring<'a> -> Vector<'a>
46-
abstract EWiseMult: Vector<'a> -> Mask1D<'a> -> Semiring<'a> -> Vector<'a>
47-
abstract Apply: Mask1D<'a> -> UnaryOp<'a, 'b> -> Vector<'b>
44+
abstract Item: Mask1D -> Vector<'a> with get, set
45+
abstract Item: int -> Scalar<'a> with get, set
46+
abstract Item: Mask1D -> Scalar<'a> with set
47+
48+
abstract Vxm: Matrix<'a> -> Mask1D -> Semiring<'a> -> Vector<'a>
49+
abstract EWiseAdd: Vector<'a> -> Mask1D -> Semiring<'a> -> Vector<'a>
50+
abstract EWiseMult: Vector<'a> -> Mask1D -> Semiring<'a> -> Vector<'a>
51+
abstract Apply: Mask1D -> UnaryOp<'a, 'b> -> Vector<'b>
4852
abstract Reduce: Monoid<'a> -> Scalar<'a>
4953

50-
abstract EWiseAddInplace: Vector<'a> -> Mask1D<'a> -> Semiring<'a> -> unit
51-
abstract EWiseMultInplace: Vector<'a> -> Mask1D<'a> -> Semiring<'a> -> unit
52-
abstract ApplyInplace: Mask1D<'a> -> UnaryOp<'a, 'b> -> unit
54+
abstract EWiseAddInplace: Vector<'a> -> Mask1D -> Semiring<'a> -> unit
55+
abstract EWiseMultInplace: Vector<'a> -> Mask1D -> Semiring<'a> -> unit
56+
abstract ApplyInplace: Mask1D -> UnaryOp<'a, 'b> -> unit
5357

5458
static member inline (+) (x: Vector<'a>, y: Vector<'a>) = x.EWiseAdd y
5559
static member inline (*) (x: Vector<'a>, y: Vector<'a>) = x.EWiseMult y
5660
static member inline (+.*) (x: Vector<'a>, y: Matrix<'a>) = x.Vxm y
5761
static member inline (.+) (x: Vector<'a>, y: Vector<'a>) = x.EWiseAddInplace y
5862
static member inline (.*) (x: Vector<'a>, y: Vector<'a>) = x.EWiseMultInplace y
5963

60-
and Mask1D<'a when 'a : struct and 'a : equality> =
61-
| Mask1D of Vector<'a>
62-
| Complemented1D of Vector<'a>
63-
| None
64+
and Mask1D(isNone: bool, indices: int[], length: int, isRegular: bool) =
65+
member this.IsNone = isNone
66+
member this.IsRegular = isRegular
67+
68+
member this.Indices = indices
69+
member this.Length = length
70+
71+
member this.Item
72+
with get (idx: int) : bool =
73+
this.IsNone || this.Indices |> Array.exists ( ( = ) idx) |> ( = ) this.IsRegular
74+
75+
static member Create (vector: Vector<'a>) = vector.CreateMask true
76+
static member Complemented (vector: Vector<'a>) = vector.CreateMask false
77+
static member None = Mask1D(true, Array.empty, 0, true)
78+
79+
and Mask2D(isNone: bool, rows: int[], columns: int[], rowCount: int, columnCount: int, isRegular: bool) =
80+
member this.IsNone = isNone
81+
member this.IsRegular = isRegular
82+
83+
member this.Rows = rows
84+
member this.Columns = columns
85+
member this.RowCount = rowCount
86+
member this.ColumnCount = columnCount
87+
88+
member this.Item
89+
with get (rowIdx: int, colIdx: int) : bool =
90+
this.IsNone || Array.zip this.Rows this.Columns |> Array.exists ( ( = ) (rowIdx, colIdx)) |> ( = ) this.IsRegular
6491

65-
and Mask2D<'a when 'a : struct and 'a : equality> =
66-
| Mask2D of Matrix<'a>
67-
| Complemented2D of Matrix<'a>
68-
| None
92+
static member Create (matrix: Matrix<'a>) = matrix.CreateMask true
93+
static member Complemented (matrix: Matrix<'a>) = matrix.CreateMask false
94+
static member None = Mask2D(true, Array.empty, Array.empty, 0, 0, true)

0 commit comments

Comments
 (0)