Question
Fill in the three methods `condition`, `sum_out`, and `normalize` (replace `# TODO`).
Include a short explanation of the code. Note that `np` refers to NumPy.
class Factor(object):
    """Represents a factor with associated operations for variable elimination.

    The table is stored densely in ``self.data``, a NumPy array with one axis
    per variable, ordered like the keys of ``self.domains``. Entry
    ``data[i0, i1, ...]`` belongs to the assignment in which the k-th variable
    takes the ``ik``-th value of its domain.
    """

    def __init__(self, domains, values=None):
        """
        Args:
            domains:
                A dictionary where the keys are variable names in the factor, and
                the values are tuples containing all possible values of the variable.
            values:
                Convenience argument for initializing the Factor's table.
                List of tuples, where each tuple is a row of the table.
                First element of each tuple is the value of the first variable, etc.
                Final element of each tuple is the value of the factor for the given
                combination of values. See `unit_tests` for example usage.
        """
        # Store every domain as a tuple so `_idx` lookups and the domain
        # comparison in `__mul__` behave the same for lists, ranges, etc.
        self.domains = {name: tuple(dom) for name, dom in dict(domains).items()}
        shape = [len(dom) for dom in self.domains.values()]
        # With no variables the shape is [] and `data` is a 0-d array holding
        # a single value, addressed by the empty key ().
        self.data = np.zeros(shape)
        if values is not None:
            for row in values:
                # Rows may be lists; the leading entries form the key.
                self[tuple(row[:-1])] = row[-1]

    # ------- Operators
    def condition(self, name, val):
        """Return a new factor that conditions on ``name=val``.

        The result no longer contains `name`. For every assignment of the
        remaining variables, it holds the entry of `self` in which `name`
        equals `val`. No renormalization is done, and `self` is not modified.

        Raises:
            ValueError: if `name` is not a variable of this factor, or `val`
                is not in the domain of `name`.
        """
        j = self._axis(name)
        new_domains = dict(self.domains)  # copy own domains...
        del new_domains[name]             # ... except for `name`
        new_f = Factor(new_domains)
        # np.take with a scalar index selects one slice along axis j and drops
        # that axis, so the remaining axes keep the order of `new_domains`.
        # np.array(...) copies and guarantees an ndarray: for a 1-D table,
        # np.take returns a NumPy scalar, which this turns into a writable
        # 0-d array.
        new_f.data = np.array(np.take(self.data, self._idx(name, val), axis=j))
        return new_f

    def sum_out(self, name):
        """Return a new factor that sums out variable `name`.

        Each entry of the result is the sum of the entries of `self` that
        agree with it on all remaining variables. `self` is not modified.

        Raises:
            ValueError: if `name` is not a variable of this factor.
        """
        j = self._axis(name)
        new_domains = dict(self.domains)  # copy own domains...
        del new_domains[name]             # ... except for `name`
        new_f = Factor(new_domains)
        # Summing along axis j removes it; the remaining axes keep the order
        # of `new_domains`. As in `condition`, np.array makes the result a
        # writable ndarray even when it is 0-d.
        new_f.data = np.array(np.sum(self.data, axis=j))
        return new_f

    def normalize(self):
        """Return a new factor whose values add to 1.

        Raises:
            ValueError: if the values of this factor sum to 0, so no
                normalization exists.
        """
        total = np.sum(self.data)
        if total == 0:
            # Dividing would silently fill the table with NaN.
            raise ValueError("Cannot normalize a factor whose values sum to 0")
        new_f = Factor(self.domains)
        # np.asarray keeps `data` an ndarray when the table is 0-d
        # (0-d / scalar division yields a NumPy scalar).
        new_f.data = np.asarray(self.data / total)
        return new_f

    def __mul__(self, other):
        """Construct a new factor by multiplying `self` by `other`.

        The result ranges over the union of both factors' variables (those of
        `self` first). Each entry is the product of the matching entries of
        the two operands.

        Raises:
            ValueError: if a shared variable has different domains in the two
                factors.
        """
        # Figure out the variables and domains for the new factor
        new_domains = dict(self.domains)
        for name, domain in other.domains.items():
            if name not in new_domains:
                new_domains[name] = domain
            elif self.domain(name) != other.domain(name):
                raise ValueError(f"Incompatible domains for {repr(name)}: "
                                 f"{repr(self.domain(name))} versus "
                                 f"{repr(other.domain(name))}")
        # Empty factor with the computed domains
        new_f = Factor(new_domains)
        # Perform the multiplications, projecting each joint assignment onto
        # the variables of each operand.
        for k in new_f.keys:
            h = dict(zip(new_f.names, k))
            k1 = tuple(h[name] for name in self.names)
            k2 = tuple(h[name] for name in other.names)
            new_f[k] = self[k1] * other[k2]
        return new_f

    # ------- Accessors
    @property
    def names(self):
        """Return the names of all the variables in the table"""
        return tuple(self.domains.keys())

    @property
    def keys(self):
        """Return all value combinations for all variables in the table"""
        return tuple(itertools.product(*self.domains.values()))

    @property
    def size(self):
        """Return the number of entries in the table"""
        return self.data.size

    def domain(self, name):
        """Return the domain of values for variable `name`"""
        return tuple(self.domains[name])

    def __getitem__(self, key):
        """Return the table entry for the tuple of values `key`.

        A non-tuple `key` is treated as a one-element tuple, so a factor over
        a single variable can be indexed as ``f[value]``.
        """
        return self.data[self._key_to_index(key)]

    def __setitem__(self, key, new_val):
        """Set the table entry for the tuple of values `key` to `new_val`.

        Accepts the same key forms as `__getitem__`.
        """
        self.data[self._key_to_index(key)] = new_val

    def _key_to_index(self, key):
        """Translate a tuple of variable values into an index into `data`."""
        if not isinstance(key, tuple):
            key = (key,)
        if len(key) != len(self.names):
            raise ValueError(f"Wrong number of arguments: "
                             f"{len(key)} instead of {len(self.names)}")
        return tuple(self._idx(name, val) for name, val in zip(self.names, key))

    def _axis(self, name):
        """Return the axis of `data` that corresponds to variable `name`"""
        try:
            return self.names.index(name)
        except ValueError:
            raise ValueError(
                f"{repr(name)} is not a variable of this factor") from None

    def _idx(self, name, val):
        """Return the index of `val` in `name`s domain"""
        try:
            return self.domains[name].index(val)
        except ValueError:
            raise ValueError(
                f"{repr(val)} is not in domain of {repr(name)}") from None

    # ------- Standard overrides for pretty printing
    def __repr__(self):
        cls = self.__class__.__name__
        return f"<{cls} object: names={list(self.names)}, rows={self.size}>"

    def __str__(self):
        """Render the table: a header row, a rule, then one row per key."""
        names = self.names
        rows = self.keys
        # Columns must fit both the variable names and every domain value.
        w = max((len(str(x)) for x in itertools.chain(names, *rows)), default=0)
        fmt = f"%{w}s " * len(names)
        lines = [fmt % names + "value",
                 fmt % tuple("-" * w for _ in names) + "-----"]
        lines.extend(fmt % k + f"{self[k]}" for k in rows)
        return "\n".join(lines)
Step-by-Step Solution
The solution involves three steps.
Step: 1
Get Instant Access to Expert-Tailored Solutions
See step-by-step solutions with expert insights and AI-powered tools for academic success
Step: 2
Step: 3
Ace Your Homework with AI
Get the answers you need in no time with our AI-driven, step-by-step assistance
Get Started