1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
use variable::Variable;

cpp! {{
  #include <CNTKLibrary.h>
  #include <cstdio>
  #include <vector>
  #include <unordered_set>

  using namespace CNTK;
  using namespace std;
}}

/// Opaque Rust-side placeholder for the C++ `unordered_set<Variable>`.
///
/// In this file it is only ever handled behind a raw pointer, so its size is
/// not expected to mirror the C++ object. NOTE(review): confirm no other code
/// relies on its layout.
type VariableSetInner = [u64; 1];

/// Owning handle to a heap-allocated C++ `unordered_set<Variable>`.
///
/// The underlying set is created by `VariableSet::new` and freed when the
/// handle is dropped.
pub struct VariableSet {
    /// Raw pointer to the C++ set, owned by this wrapper.
    pub(super) payload: *mut VariableSetInner,
}

impl VariableSet {
    /// Creates empty VariableSet
    pub fn new() -> VariableSet {
        VariableSet {
            payload: unsafe {
                cpp!([] -> *mut VariableSetInner as "unordered_set<Variable>*" {
                    return new unordered_set<Variable>;
                })
            }
        }
    }

    /// Adds Variable to set
    pub fn add<T: Into<Variable>>(&mut self, variable: T) {
        let vv = variable.into();
        let var_payload = vv.payload;
        let mut payload = self.payload;

        unsafe {
            cpp!([mut payload as "unordered_set<Variable>*", var_payload as "Variable"] {
                payload->insert(var_payload);
            })
        }
    }
}

impl Drop for VariableSet {
    /// Frees the C++ `unordered_set<Variable>` allocated by `VariableSet::new`.
    fn drop(&mut self) {
        // `cpp!` captures plain identifiers, so copy the pointer out of `self`.
        let set_ptr = self.payload;
        // SAFETY: `set_ptr` was produced by `new unordered_set<Variable>` in
        // `VariableSet::new`, and this is the only place it is deleted.
        unsafe {
            cpp!([set_ptr as "unordered_set<Variable>*"] {
                delete set_ptr;
            });
        }
    }
}

/// Builds a `VariableSet` from a comma-separated list of expressions.
///
/// Each expression is passed to `VariableSet::add`, so it may be anything that
/// implements `Into<Variable>`. A trailing comma is accepted, and an empty
/// invocation yields an empty set. The expansion names `VariableSet`
/// unqualified, so that type must be in scope at the call site.
///
/// The `@single` and `@count` arms are helpers: `@count` evaluates to the number
/// of comma-separated expressions it is given.
#[macro_export]
macro_rules! variableset {
    (@single $($tok:tt)*) => (());
    (@count $($item:expr),*) => (<[()]>::len(&[$(variableset!(@single $item)),*]));

    ($($elem:expr,)+) => { variableset!($($elem),+) };
    ($($elem:expr),*) => {
        {
            let mut set = VariableSet::new();
            $( set.add($elem); )*
            set
        }
    };
}

#[cfg(test)]
mod tests {
    use super::*;
    use variable::*;
    use shape::Shape;

    #[test]
    fn test_create() {
        let _set = VariableSet::new();
    }

    #[test]
    fn test_add() {
        let mut set = VariableSet::new();
        let var = Variable::input_variable(&Shape::scalar());

        set.add(&var);
    }

    #[test]
    fn test_add_by_value_and_duplicate() {
        let mut set = VariableSet::new();
        let var = Variable::input_variable(&Shape::scalar());

        // Covers both the by-reference and by-value `Into<Variable>` paths,
        // and inserting a variable that is already present.
        set.add(&var);
        set.add(var);
    }

    #[test]
    fn test_macro() {
        let var = Variable::input_variable(&Shape::scalar());

        // Trailing-comma form of the exported `variableset!` macro.
        let _set = variableset![&var, &var,];

        // `@count` helper counts its comma-separated arguments.
        assert_eq!(variableset!(@count &var, &var), 2);
    }
}