Skip to content

Commit 8f1bf13

Browse files
committed
Implement masks as DU
1 parent 4e95a3b commit 8f1bf13

4 files changed

Lines changed: 52 additions & 77 deletions

File tree

src/GraphBLAS-sharp/CSRMatrix.fs

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ with
1919
ColumnCount = 0
2020
}
2121

22-
type CSRMatrix<'a when 'a : struct>(csrTuples: CSRFormat<'a>) =
22+
type CSRMatrix<'a when 'a : struct and 'a : equality>(csrTuples: CSRFormat<'a>) =
2323
inherit Matrix<'a>()
2424

2525
let rowCount = csrTuples.RowPointers.Length - 1
2626
let columnCount = csrTuples.ColumnCount
2727

28-
let spMV (vector: Vector<'a>) (mask: Mask1D option) (context: Semiring<'a>) : Vector<'a> =
28+
let spMV (vector: Vector<'a>) (mask: Mask1D<'a>) (context: Semiring<'a>) : Vector<'a> =
2929
let csrMatrixRowCount = rowCount
3030
let csrMatrixColumnCount = columnCount
3131
let vectorLength = vector.Length
@@ -79,26 +79,25 @@ type CSRMatrix<'a when 'a : struct>(csrTuples: CSRFormat<'a>) =
7979

8080
override this.RowCount = rowCount
8181
override this.ColumnCount = columnCount
82-
override this.Mask = failwith "Not Implemented"
8382

8483
override this.Item
85-
with get (mask: Mask2D option) : Matrix<'a> = failwith "Not Implemented"
86-
and set (mask: Mask2D option) (value: Matrix<'a>) = failwith "Not Implemented"
84+
with get (mask: Mask2D<'a>) : Matrix<'a> = failwith "Not Implemented"
85+
and set (mask: Mask2D<'a>) (value: Matrix<'a>) = failwith "Not Implemented"
8786
override this.Item
88-
with get (vectorMask: Mask1D option, colIdx: int) : Vector<'a> = failwith "Not Implemented"
89-
and set (vectorMask: Mask1D option, colIdx: int) (value: Vector<'a>) = failwith "Not Implemented"
87+
with get (vectorMask: Mask1D<'a>, colIdx: int) : Vector<'a> = failwith "Not Implemented"
88+
and set (vectorMask: Mask1D<'a>, colIdx: int) (value: Vector<'a>) = failwith "Not Implemented"
9089
override this.Item
91-
with get (rowIdx: int, vectorMask: Mask1D option) : Vector<'a> = failwith "Not Implemented"
92-
and set (rowIdx: int, vectorMask: Mask1D option) (value: Vector<'a>) = failwith "Not Implemented"
90+
with get (rowIdx: int, vectorMask: Mask1D<'a>) : Vector<'a> = failwith "Not Implemented"
91+
and set (rowIdx: int, vectorMask: Mask1D<'a>) (value: Vector<'a>) = failwith "Not Implemented"
9392
override this.Item
9493
with get (rowIdx: int, colIdx: int) : Scalar<'a> = failwith "Not Implemented"
9594
and set (rowIdx: int, colIdx: int) (value: Scalar<'a>) = failwith "Not Implemented"
9695
override this.Item
97-
with set (mask: Mask2D option) (value: Scalar<'a>) = failwith "Not Implemented"
96+
with set (mask: Mask2D<'a>) (value: Scalar<'a>) = failwith "Not Implemented"
9897
override this.Item
99-
with set (vectorMask: Mask1D option, colIdx: int) (value: Scalar<'a>) = failwith "Not Implemented"
98+
with set (vectorMask: Mask1D<'a>, colIdx: int) (value: Scalar<'a>) = failwith "Not Implemented"
10099
override this.Item
101-
with set (rowIdx: int, vectorMask: Mask1D option) (value: Scalar<'a>) = failwith "Not Implemented"
100+
with set (rowIdx: int, vectorMask: Mask1D<'a>) (value: Scalar<'a>) = failwith "Not Implemented"
102101

103102
override this.Mxm a b c = failwith "Not Implemented"
104103
override this.Mxv a b c = failwith "Not Implemented"

src/GraphBLAS-sharp/DenseVector.fs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
namespace GraphBLAS.FSharp
22

3-
type DenseVector<'a when 'a : struct>(vector: 'a[]) =
3+
type DenseVector<'a when 'a : struct and 'a : equality>(vector: 'a[]) =
44
inherit Vector<'a>()
55

66
new() = DenseVector(Array.zeroCreate<'a> 0)
77
new(listOfIndices: int list) = DenseVector(Array.zeroCreate<'a> 0)
88

99
override this.Length = failwith "Not Implemented"
10-
override this.Mask = failwith "Not Implemented"
1110
override this.AsArray = failwith "Not Implemented"
1211

1312
override this.Item
14-
with get (mask: Mask1D option) : Vector<'a> = failwith "Not Implemented"
15-
and set (mask: Mask1D option) (value: Vector<'a>) = failwith "Not Implemented"
13+
with get (mask: Mask1D<'a>) : Vector<'a> = failwith "Not Implemented"
14+
and set (mask: Mask1D<'a>) (value: Vector<'a>) = failwith "Not Implemented"
1615
override this.Item
1716
with get (rowIdx: int, colIdx: int) : Scalar<'a> = failwith "Not Implemented"
1817
and set (rowIdx: int, colIdx: int) (value: Scalar<'a>) = failwith "Not Implemented"
1918
override this.Item
20-
with set (mask: Mask1D option) (value: Scalar<'a>) = failwith "Not Implemented"
19+
with set (mask: Mask1D<'a>) (value: Scalar<'a>) = failwith "Not Implemented"
2120

2221
override this.Vxm a b c = failwith "Not Implemented"
2322
override this.EWiseAdd a b c = failwith "Not Implemented"
Lines changed: 36 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,30 @@
11
namespace GraphBLAS.FSharp
22

33
[<AbstractClass>]
4-
type Matrix<'a when 'a : struct>() =
4+
type Matrix<'a when 'a : struct and 'a : equality>() =
55
abstract RowCount: int
66
abstract ColumnCount: int
7-
abstract Mask: Mask2D
87

9-
abstract Item: Mask2D option -> Matrix<'a> with get, set
10-
abstract Item: Mask1D option * int -> Vector<'a> with get, set
11-
abstract Item: int * Mask1D option -> Vector<'a> with get, set
8+
abstract Item: Mask2D<'a> -> Matrix<'a> with get, set
9+
abstract Item: Mask1D<'a> * int -> Vector<'a> with get, set
10+
abstract Item: int * Mask1D<'a> -> Vector<'a> with get, set
1211
abstract Item: int * int -> Scalar<'a> with get, set
13-
abstract Item: Mask2D option -> Scalar<'a> with set
14-
abstract Item: Mask1D option * int -> Scalar<'a> with set
15-
abstract Item: int * Mask1D option -> Scalar<'a> with set
16-
17-
abstract Mxm: Matrix<'a> -> Mask2D option -> Semiring<'a> -> Matrix<'a>
18-
abstract Mxv: Vector<'a> -> Mask1D option -> Semiring<'a> -> Vector<'a>
19-
abstract EWiseAdd: Matrix<'a> -> Mask2D option -> Semiring<'a> -> Matrix<'a>
20-
abstract EWiseMult: Matrix<'a> -> Mask2D option -> Semiring<'a> -> Matrix<'a>
21-
abstract Apply: Mask1D option -> UnaryOp<'a, 'b> -> Matrix<'b>
22-
abstract ReduceIn: Mask1D option -> Monoid<'a> -> Vector<'a>
23-
abstract ReduceOut: Mask1D option -> Monoid<'a> -> Vector<'a>
12+
abstract Item: Mask2D<'a> -> Scalar<'a> with set
13+
abstract Item: Mask1D<'a> * int -> Scalar<'a> with set
14+
abstract Item: int * Mask1D<'a> -> Scalar<'a> with set
15+
16+
abstract Mxm: Matrix<'a> -> Mask2D<'a> -> Semiring<'a> -> Matrix<'a>
17+
abstract Mxv: Vector<'a> -> Mask1D<'a> -> Semiring<'a> -> Vector<'a>
18+
abstract EWiseAdd: Matrix<'a> -> Mask2D<'a> -> Semiring<'a> -> Matrix<'a>
19+
abstract EWiseMult: Matrix<'a> -> Mask2D<'a> -> Semiring<'a> -> Matrix<'a>
20+
abstract Apply: Mask1D<'a> -> UnaryOp<'a, 'b> -> Matrix<'b>
21+
abstract ReduceIn: Mask1D<'a> -> Monoid<'a> -> Vector<'a>
22+
abstract ReduceOut: Mask1D<'a> -> Monoid<'a> -> Vector<'a>
2423
abstract T: Matrix<'a>
2524

26-
abstract EWiseAddInplace: Matrix<'a> -> Mask2D option -> Semiring<'a> -> unit
27-
abstract EWiseMultInplace: Matrix<'a> -> Mask2D option -> Semiring<'a> -> unit
28-
abstract ApplyInplace: Mask2D option -> UnaryOp<'a, 'b> -> unit
25+
abstract EWiseAddInplace: Matrix<'a> -> Mask2D<'a> -> Semiring<'a> -> unit
26+
abstract EWiseMultInplace: Matrix<'a> -> Mask2D<'a> -> Semiring<'a> -> unit
27+
abstract ApplyInplace: Mask2D<'a> -> UnaryOp<'a, 'b> -> unit
2928

3029
static member inline (+) (x: Matrix<'a>, y: Matrix<'a>) = x.EWiseAdd y
3130
static member inline (*) (x: Matrix<'a>, y: Matrix<'a>) = x.EWiseMult y
@@ -34,58 +33,36 @@ type Matrix<'a when 'a : struct>() =
3433
static member inline (.+) (x: Matrix<'a>, y: Matrix<'a>) = x.EWiseAddInplace y
3534
static member inline (.*) (x: Matrix<'a>, y: Matrix<'a>) = x.EWiseMultInplace y
3635

37-
and [<AbstractClass>] Vector<'a when 'a : struct>() =
36+
and [<AbstractClass>] Vector<'a when 'a : struct and 'a : equality>() =
3837
abstract Length: int
39-
abstract Mask: Mask1D
4038
abstract AsArray: 'a[]
4139

42-
abstract Item: Mask1D option -> Vector<'a> with get, set
40+
abstract Item: Mask1D<'a> -> Vector<'a> with get, set
4341
abstract Item: int * int -> Scalar<'a> with get, set
44-
abstract Item: Mask1D option -> Scalar<'a> with set
42+
abstract Item: Mask1D<'a> -> Scalar<'a> with set
4543

46-
abstract Vxm: Matrix<'a> -> Mask1D option -> Semiring<'a> -> Vector<'a>
47-
abstract EWiseAdd: Vector<'a> -> Mask1D option -> Semiring<'a> -> Vector<'a>
48-
abstract EWiseMult: Vector<'a> -> Mask1D option -> Semiring<'a> -> Vector<'a>
49-
abstract Apply: Mask1D option -> UnaryOp<'a, 'b> -> Vector<'b>
44+
abstract Vxm: Matrix<'a> -> Mask1D<'a> -> Semiring<'a> -> Vector<'a>
45+
abstract EWiseAdd: Vector<'a> -> Mask1D<'a> -> Semiring<'a> -> Vector<'a>
46+
abstract EWiseMult: Vector<'a> -> Mask1D<'a> -> Semiring<'a> -> Vector<'a>
47+
abstract Apply: Mask1D<'a> -> UnaryOp<'a, 'b> -> Vector<'b>
5048
abstract Reduce: Monoid<'a> -> Scalar<'a>
5149

52-
abstract EWiseAddInplace: Vector<'a> -> Mask1D option -> Semiring<'a> -> unit
53-
abstract EWiseMultInplace: Vector<'a> -> Mask1D option -> Semiring<'a> -> unit
54-
abstract ApplyInplace: Mask1D option -> UnaryOp<'a, 'b> -> unit
50+
abstract EWiseAddInplace: Vector<'a> -> Mask1D<'a> -> Semiring<'a> -> unit
51+
abstract EWiseMultInplace: Vector<'a> -> Mask1D<'a> -> Semiring<'a> -> unit
52+
abstract ApplyInplace: Mask1D<'a> -> UnaryOp<'a, 'b> -> unit
5553

5654
static member inline (+) (x: Vector<'a>, y: Vector<'a>) = x.EWiseAdd y
5755
static member inline (*) (x: Vector<'a>, y: Vector<'a>) = x.EWiseMult y
5856
static member inline (+.*) (x: Vector<'a>, y: Matrix<'a>) = x.Vxm y
5957
static member inline (.+) (x: Vector<'a>, y: Vector<'a>) = x.EWiseAddInplace y
6058
static member inline (.*) (x: Vector<'a>, y: Vector<'a>) = x.EWiseMultInplace y
6159

62-
and Mask1D(size: int, indexList: int list) =
63-
64-
member this.Item
65-
with get (idx: int) = indexList.[idx]
66-
67-
member this.GetComplement() =
68-
let indices = Set.ofList indexList
69-
let allIndices = List.init size id |> Set.ofList
70-
let complementIndices = Set.difference allIndices indices |> Set.toList
71-
Mask1D(size, complementIndices)
72-
73-
member this.GetEnumerator() = (indexList |> List.toSeq).GetEnumerator()
74-
75-
static member (~~) (mask: Mask1D) = mask.GetComplement()
76-
77-
and Mask2D(size: int, indexList: (int * int) list) =
78-
79-
member this.Item
80-
with get (idx: int) = indexList.[idx]
81-
82-
member this.GetComplement() =
83-
let indices = Set.ofList indexList
84-
let allIndices = List.init size (fun i -> (i, i)) |> Set.ofList
85-
let complementIndices = Set.difference allIndices indices |> Set.toList
86-
Mask2D(size, complementIndices)
87-
88-
member this.GetEnumerator() = (indexList |> List.toSeq).GetEnumerator()
89-
90-
static member (~~) (mask: Mask2D) = mask.GetComplement()
60+
and Mask1D<'a when 'a : struct and 'a : equality> =
61+
| Mask1D of Vector<'a>
62+
| Complemented1D of Vector<'a>
63+
| None
9164

65+
and Mask2D<'a when 'a : struct and 'a : equality> =
66+
| Mask2D of Matrix<'a>
67+
| Complemented2D of Matrix<'a>
68+
| None

src/GraphBLAS-sharp/Scalar.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
namespace GraphBLAS.FSharp
22

3-
type Scalar<'a when 'a : struct> = Scalar of 'a
3+
type Scalar<'a when 'a : struct and 'a : equality> = Scalar of 'a
44
with
55
static member op_Implicit (Scalar source) = source

0 commit comments

Comments
 (0)