NN Class
The unified neural network class for feedforward, convolutional, recurrent, segmentation, and ODE-based architectures.
Constructors

```matlab
nn = NN(Layers)
nn = NN(Layers, Connections)
nn = NN(Layers, Connections, inputSize, outputSize, name)
```

- `Layers`: cell array of layer objects (`FullyConnectedLayer`, `ReluLayer`, etc.)
- `Connections`: table defining the DAG connectivity (optional; omit it for sequential networks)
- `inputSize`, `outputSize`: arrays giving the input and output dimensions
- `name`: string identifier for the network
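Conceptually, a sequential network is the special case of a DAG in which each layer feeds only the next one. The sketch below is plain MATLAB, not NNV: the layer names are hypothetical, and two cell arrays stand in for the source/destination columns of the `Connections` table, whose exact column names this page does not specify. It builds the edge list for a chain and then adds one skip edge, the kind of topology that an ordered `Layers` list alone cannot express.

```matlab
% Hypothetical layer names (illustration only, not generated by NNV).
names = {'fc_1', 'relu_1', 'fc_2', 'add_1'};

% Sequential part: each layer feeds the next one in the list.
source      = names(1:end-1);
destination = names(2:end);

% A skip connection from relu_1 straight into add_1 turns the chain into a
% DAG; this is the situation in which Connections must be passed explicitly.
source{end+1}      = 'relu_1';
destination{end+1} = 'add_1';

% Count incoming edges per layer: add_1 now has two inputs.
inDegree = cellfun(@(n) sum(strcmp(destination, n)), names);
disp([names; num2cell(inDegree)]);
```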
Properties

| Property | Type | Description |
|---|---|---|
| | cell array | Ordered list of layer objects |
| | table | DAG connectivity (source/destination layer names) |
| | string | Network name |
| | int | Number of layers |
| | int | Total neuron count across all layers |
| | array | Input dimensions |
| | array | Output dimensions |
| | cell array | Computed reachable set per layer (after calling reach) |
| | array | Computation time per layer (after calling reach) |
Methods

| Method | Description |
|---|---|
| | Forward pass on a single concrete input vector, image, or tensor |
| | Compute the reachable output set from an input set (Star, ImageStar, etc.) |
| | Verify classification robustness. Returns 1 (robust), 0 (not robust), or 2 (unknown). With exact methods the reachable set is exact, so a result that cannot be proven robust is reported as 0 rather than 2. |
| | Verify a property specified in a VNNLIB file. Returns the result code (0/1/2, as above) and the input set `X`. |
| | Verify safety with respect to an unsafe output region given as a `HalfSpace` |
| | Verify pixel-level robustness for semantic segmentation networks |
| | Classify input(s); handles both single points and sets |
| | Check whether a precomputed reachable set satisfies robustness for a target class |
| | Search for counterexample inputs that violate a specification, via random sampling |
| | Verify robustness of RNN/sequence classifiers over time steps |
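The random-sampling counterexample search listed above can be illustrated in plain MATLAB without NNV. This is a minimal sketch of the general idea, not NNV's implementation, assuming a hypothetical FC, ReLU, FC network with made-up weights: draw uniform samples from an input box, run a concrete forward pass on each, and keep every sample whose predicted class differs from the target. Any hit is a genuine witness that the box is not robust (the 0 case); finding no hit would prove nothing.

```matlab
% Hypothetical weights for a 2-2-2 network: FC -> ReLU -> FC.
W1 = [1 1; 1 -1];  b1 = [0; 0];
W2 = [2 0; 0 1];   b2 = [0; 0];
forward = @(x) W2 * max(W1*x + b1, 0) + b2;   % concrete forward pass

lb = [-1; -1];  ub = [1; 1];  target = 1;     % input box and expected class
rng(0);                                       % fixed seed for reproducibility
nSamples = 1000;
X = lb + (ub - lb) .* rand(2, nSamples);      % uniform samples inside the box

counterexamples = zeros(2, 0);
for k = 1:nSamples
    [~, predicted] = max(forward(X(:, k)));   % index of the largest output
    if predicted ~= target
        counterexamples(:, end+1) = X(:, k); %#ok<AGROW>
    end
end
fprintf('%d of %d samples are counterexamples\n', size(counterexamples, 2), nSamples);
```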
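Example

```matlab
% Build a simple network (W1, b1, W2, b2 are placeholder values for illustration)
W1 = [1 -1; 1 1];  b1 = [0; 0];
W2 = [1 0; 0 1];   b2 = [0; 0];
layers = {FullyConnectedLayer(W1, b1), ReluLayer(), FullyConnectedLayer(W2, b2)};
net = NN(layers);
% Evaluate on a concrete input
y = net.evaluate([0.5; 0.5]);
% Compute the reachable output set and verify robustness for class 1
input_set = Star([-1; -1], [1; 1]);   % the box [-1, 1]^2 as a star set
reachOptions.reachMethod = 'approx-star';
output_sets = net.reach(input_set, reachOptions);
result = net.verify_robustness(input_set, reachOptions, 1);  % 1 = robust, 0 = not robust, 2 = unknown
```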
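To make the `reach` / `verify_robustness` pair above concrete without NNV, the sketch below swaps in interval bound propagation, a simpler and generally looser technique than the star-set reachability used by `'approx-star'`. The weights are hypothetical, with the same FC, ReLU, FC shape as the example. The decision rule returns 1 when the target class's lower bound exceeds every other class's upper bound and 2 otherwise; interval bounds alone can never justify 0, since that requires a concrete counterexample.

```matlab
% Hypothetical weights for a 2-2-2 network: FC -> ReLU -> FC.
W1 = [1 1; 1 -1];  b1 = [0; 0];
W2 = [2 0; 0 1];   b2 = [0; 0];

% Interval propagation through y = W*x + b, returned as [lower, upper]:
% the box centre maps exactly and the radius scales by abs(W).
affineBox = @(W, b, lb, ub) [W*(lb+ub)/2 + b - abs(W)*(ub-lb)/2, ...
                             W*(lb+ub)/2 + b + abs(W)*(ub-lb)/2];
% Second FC layer applied to an n-by-2 [lower, upper] matrix.
fc2 = @(H) affineBox(W2, b2, H(:,1), H(:,2));
% Whole network: FC, then ReLU (monotone, so clamp both bounds at 0), then FC.
outBox = @(lb, ub) fc2(max(affineBox(W1, b1, lb, ub), 0));

% Robust if the target's lower bound beats every other class's upper bound.
isRobust = @(Y, t) Y(t,1) > max(Y(setdiff(1:size(Y,1), t), 2));

Ya = outBox([0.25; 0.25], [0.75; 0.75]);  % small box around (0.5, 0.5)
Yb = outBox([-1; -1], [1; 1]);            % the example's input box
resultA = 2 - isRobust(Ya, 1);            % 1 = robust, 2 = unknown
resultB = 2 - isRobust(Yb, 1);
fprintf('resultA = %d, resultB = %d\n', resultA, resultB);
```

On the small box the output intervals separate and the sketch reports 1; on the box [-1, 1]^2 they overlap and it reports 2, the "unknown" outcome that any over-approximating method can produce.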