git.haldean.org
simple nodes and node hierarchy visualization Will Haldean Brown 6 years ago
3 changed file(s) with 116 addition(s) and 0 deletion(s).
 0 __pycache__ 1 *.pyc 2 *.dot 3 *.png
 0 import math 1 2 def const(val): 3 return Node("const {}".format(val), lambda _: val, []) 4 5 def sum(n1, n2): 6 return Node("add", lambda t: n1(t) + n2(t), [n1, n2]) 7 8 def product(n1, n2): 9 return Node("product", lambda t: n1(t) * n2(t), [n1, n2]) 10 11 def sine(freq): 12 return Node( 13 "sin", lambda t: math.sin(t * freq(t) / (2 * math.pi)), [freq]) 14 15 class Node(object): 16 _next_id = 0 17 18 def __init__(self, name, func, deps): 19 self.id = Node._next_id 20 Node._next_id += 1 21 22 self.name = name 23 self.func = func 24 self.deps = deps 25 26 def __call__(self, t): 27 return self.func(t) 28 29 def collect_nodes(root): 30 return set([root]).union( 31 reduce(set.union, map(collect_nodes, root.deps), set())) 32 33 def collect_edges(root): 34 edges = [(root.id, dep.id) for dep in root.deps] 35 for dep in root.deps: 36 edges.extend(collect_edges(dep)) 37 return edges 38 39 def to_dot(root, stream, name="symrep"): 40 stream.write("digraph {name} {{\n".format(name=name)) 41 for node in collect_nodes(root): 42 stream.write("{id} [label=\"{name}\"];\n".format( 43 id=node.id, name=node.name)) 44 for n1, n2 in collect_edges(root): 45 stream.write("{n2} -> {n1};\n".format(n1=n1, n2=n2)) 46 stream.write("}\n")
 0 import symrep 1 import unittest 2 3 4 class SymrepTest(unittest.TestCase): 5 def test_const(self): 6 n = symrep.const(4) 7 self.assertEqual(n(0), 4) 8 self.assertEqual(n(1), 4) 9 self.assertEqual(n(2), 4) 10 11 def test_sum(self): 12 n1 = symrep.const(1) 13 n2 = symrep.const(2) 14 15 n = symrep.sum(n1, n2) 16 self.assertEqual(n(0), 3) 17 18 n = symrep.sum(n1, n1) 19 self.assertEqual(n(0), 2) 20 21 def test_product(self): 22 n = symrep.product(symrep.const(1), symrep.const(-3)) 23 self.assertEqual(n(0), -3) 24 25 n1 = symrep.const(-4) 26 n = symrep.product(symrep.const(-4), symrep.const(.5)) 27 self.assertEqual(n(0), -2) 28 29 def test_sine(self): 30 n = symrep.sine(symrep.const(1)) 31 self.assertEqual(n(0), 0) 32 self.assertEqual(n(0.25), 1) 33 self.assertEqual(n(0.5), 0) 34 self.assertEqual(n(0.75), -1) 35 self.assertEqual(n(1), 0) 36 37 def test_collect_nodes(self): 38 n1 = symrep.const(1) 39 n2 = symrep.sine(n1) 40 n3 = symrep.sum(n1, n2) 41 n4 = symrep.const(-2) 42 n5 = symrep.product(n3, n4) 43 nodes = symrep.collect_nodes(n5) 44 self.assertEqual(len(nodes), 5) 45 self.assertSetEqual( 46 set(n.id for n in nodes), 47 set((n1.id, n2.id, n3.id, n4.id, n5.id))) 48 49 def test_dot(self): 50 n = symrep.product( 51 symrep.sine(symrep.const(2)), 52 symrep.product( 53 symrep.sine(symrep.const(3)), 54 symrep.sum( 55 symrep.const(1), 56 symrep.const(1), 57 ) 58 ) 59 ) 60 with open("test.dot", "w") as f: 61 symrep.to_dot(n, f, name="test_dot") 62 63 if __name__ == "__main__": 64 unittest.main()