git.haldean.org symrep / 2b1c08e
simple nodes and node hierarchy visualization Will Haldean Brown 6 years ago
3 changed file(s) with 116 addition(s) and 0 deletion(s). Raw diff Collapse all Expand all
0 __pycache__
1 *.pyc
2 *.dot
3 *.png
0 import math
1
2 def const(val):
3 return Node("const {}".format(val), lambda _: val, [])
4
5 def sum(n1, n2):
6 return Node("add", lambda t: n1(t) + n2(t), [n1, n2])
7
8 def product(n1, n2):
9 return Node("product", lambda t: n1(t) * n2(t), [n1, n2])
10
11 def sine(freq):
12 return Node(
13 "sin", lambda t: math.sin(t * freq(t) / (2 * math.pi)), [freq])
14
15 class Node(object):
16 _next_id = 0
17
18 def __init__(self, name, func, deps):
19 self.id = Node._next_id
20 Node._next_id += 1
21
22 self.name = name
23 self.func = func
24 self.deps = deps
25
26 def __call__(self, t):
27 return self.func(t)
28
29 def collect_nodes(root):
30 return set([root]).union(
31 reduce(set.union, map(collect_nodes, root.deps), set()))
32
33 def collect_edges(root):
34 edges = [(root.id, dep.id) for dep in root.deps]
35 for dep in root.deps:
36 edges.extend(collect_edges(dep))
37 return edges
38
39 def to_dot(root, stream, name="symrep"):
40 stream.write("digraph {name} {{\n".format(name=name))
41 for node in collect_nodes(root):
42 stream.write("{id} [label=\"{name}\"];\n".format(
43 id=node.id, name=node.name))
44 for n1, n2 in collect_edges(root):
45 stream.write("{n2} -> {n1};\n".format(n1=n1, n2=n2))
46 stream.write("}\n")
0 import symrep
1 import unittest
2
3
4 class SymrepTest(unittest.TestCase):
5 def test_const(self):
6 n = symrep.const(4)
7 self.assertEqual(n(0), 4)
8 self.assertEqual(n(1), 4)
9 self.assertEqual(n(2), 4)
10
11 def test_sum(self):
12 n1 = symrep.const(1)
13 n2 = symrep.const(2)
14
15 n = symrep.sum(n1, n2)
16 self.assertEqual(n(0), 3)
17
18 n = symrep.sum(n1, n1)
19 self.assertEqual(n(0), 2)
20
21 def test_product(self):
22 n = symrep.product(symrep.const(1), symrep.const(-3))
23 self.assertEqual(n(0), -3)
24
25 n1 = symrep.const(-4)
26 n = symrep.product(symrep.const(-4), symrep.const(.5))
27 self.assertEqual(n(0), -2)
28
29 def test_sine(self):
30 n = symrep.sine(symrep.const(1))
31 self.assertEqual(n(0), 0)
32 self.assertEqual(n(0.25), 1)
33 self.assertEqual(n(0.5), 0)
34 self.assertEqual(n(0.75), -1)
35 self.assertEqual(n(1), 0)
36
37 def test_collect_nodes(self):
38 n1 = symrep.const(1)
39 n2 = symrep.sine(n1)
40 n3 = symrep.sum(n1, n2)
41 n4 = symrep.const(-2)
42 n5 = symrep.product(n3, n4)
43 nodes = symrep.collect_nodes(n5)
44 self.assertEqual(len(nodes), 5)
45 self.assertSetEqual(
46 set(n.id for n in nodes),
47 set((n1.id, n2.id, n3.id, n4.id, n5.id)))
48
49 def test_dot(self):
50 n = symrep.product(
51 symrep.sine(symrep.const(2)),
52 symrep.product(
53 symrep.sine(symrep.const(3)),
54 symrep.sum(
55 symrep.const(1),
56 symrep.const(1),
57 )
58 )
59 )
60 with open("test.dot", "w") as f:
61 symrep.to_dot(n, f, name="test_dot")
62
63 if __name__ == "__main__":
64 unittest.main()