Preserving the colours in a STEP file when modifying the geometry in Open Cascade

I'm writing a script in Python using Open Cascade Technology (via the pyOCCT package for Anaconda) to import STEP files, defeature them procedurally and re-export them. I want to preserve the product hierarchy, names and colours as much as possible. Currently the script can import STEP files, simplify all of the geometry while roughly preserving the hierarchy, and re-export the STEP file. The problem is that, no matter how I approach it, I can't get it to preserve the colours of the STEP file in a few particular cases.
Here's the model I pass in to the script:
[Image: Input]

And here's the result of the simplification:
[Image: Simplified]

In this case, the simplification has worked correctly, but the colours of some of the bodies were not preserved. The common thread is that the bodies that lose their colours are children of products which only have other bodies as their children (i.e. they don't contain sub-products). This seems to be related to the way Open Cascade imports STEP files; this one is translated as follows:
[Image: OCCT translation results]
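To make the pattern concrete, here is a small pure-Python sketch (no OCCT involved) that flags products whose children are all bodies, i.e. the ones whose bodies lose their colours. The `Node` class and the example tree are hypothetical stand-ins; the real hierarchy of testcolours.step may differ.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    """Hypothetical stand-in for a STEP product or body (not an OCCT type)."""
    name: str
    is_body: bool = False
    children: List["Node"] = field(default_factory=list)


def bodies_only_products(node: Node) -> List[str]:
    """Return names of products whose children are all bodies (no sub-products)."""
    found = []
    if not node.is_body and node.children and all(c.is_body for c in node.children):
        found.append(node.name)
    for child in node.children:
        found.extend(bodies_only_products(child))
    return found


# Hypothetical tree loosely following the description above, not the real file.
root = Node("Assembly", children=[
    Node("Component1", children=[
        Node("Component3", children=[Node("Body1", is_body=True),
                                     Node("Body2", is_body=True)]),
    ]),
    Node("Component2", children=[
        Node("Component4", children=[Node("Body3", is_body=True)]),
        Node("Body4", is_body=True),
    ]),
])

print(bodies_only_products(root))  # ['Component3', 'Component4']
```

Component2 is not flagged because it mixes a sub-product with a body.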

Alright, now for some code (I can provide some C++ snippets if needed):

from OCCT.STEPControl import STEPControl_Reader, STEPControl_Writer, STEPControl_AsIs
from OCCT.BRepAlgoAPI import BRepAlgoAPI_Defeaturing
from OCCT.TopAbs import TopAbs_FACE, TopAbs_SHAPE, TopAbs_COMPOUND
from OCCT.TopExp import TopExp_Explorer
from OCCT.ShapeFix import ShapeFix_Shape
from OCCT.GProp import GProp_GProps
from OCCT.BRepGProp import BRepGProp
from OCCT.TopoDS import TopoDS
from OCCT.TopTools import TopTools_ListOfShape
from OCCT.BRep import BRep_Tool
from OCCT.Quantity import Quantity_ColorRGBA
from OCCT.ShapeBuild import ShapeBuild_ReShape

from OCCT.STEPCAFControl import STEPCAFControl_Reader, STEPCAFControl_Writer
from OCCT.XCAFApp import XCAFApp_Application
from OCCT.XCAFDoc import XCAFDoc_DocumentTool, XCAFDoc_ColorGen, XCAFDoc_ColorSurf 
from OCCT.XmlXCAFDrivers import XmlXCAFDrivers
from OCCT.TCollection import TCollection_ExtendedString
from OCCT.TDF import TDF_LabelSequence
from OCCT.TDataStd import TDataStd_Name
from OCCT.TDocStd import TDocStd_Document
from OCCT.TNaming import TNaming_NamedShape
from OCCT.Interface import Interface_Static

# DBG
def export_step(shape, path):
    writer = STEPControl_Writer()
    writer.Transfer( shape, STEPControl_AsIs )
    writer.Write(path)

# DBG
def print_shape_type(label, shapeTool):
    if shapeTool.IsFree_(label):
        print("Free")
    if shapeTool.IsShape_(label):
        print("Shape")
    if shapeTool.IsSimpleShape_(label):
        print("SimpleShape")
    if shapeTool.IsReference_(label):
        print("Reference")
    if shapeTool.IsAssembly_(label):
        print("Assembly")
    if shapeTool.IsComponent_(label):
        print("Component")
    if shapeTool.IsCompound_(label):
        print("Compound")
    if shapeTool.IsSubShape_(label):
        print("SubShape")

# Returns a ListOfShape containing the faces to be removed in the defeaturing
# NOTE: For conciseness I've simplified this algorithm and as such it *MAY* not produce exactly 
# the same output as shown in the screenshots but should still do SOME simplification
def select_faces(shape):
    exp = TopExp_Explorer(shape, TopAbs_FACE)
    selection = TopTools_ListOfShape()
    nfaces = 0
    while exp.More():
        rgb = None
        s = exp.Current()
        exp.Next()
        nfaces += 1

        face = TopoDS.Face_(s)
        gprops = GProp_GProps()
        BRepGProp.SurfaceProperties_(face, gprops)
        area = gprops.Mass()

        surf = BRep_Tool.Surface_(face)

        if area < 150:
            selection.Append(face)
            #log(f"\t\tRemoving face with area: {area}")

    return selection, nfaces

# Performs the defeaturing
def simplify(shape):
    defeaturer = BRepAlgoAPI_Defeaturing()
    defeaturer.SetShape(shape)

    sel = select_faces(shape)
    if sel[0].Extent() == 0:
        return shape

    defeaturer.AddFacesToRemove(sel[0])
    defeaturer.SetRunParallel(True)
    defeaturer.SetToFillHistory(False)
    defeaturer.Build()

    if (not defeaturer.IsDone()):
        return shape  # TODO: Handle errors
    return defeaturer.Shape()

# Given the label of an entity, finds its displayed colour. If the entity has no defined colour, the parents are searched for defined colours as well.
def find_color(label, colorTool):
    col = Quantity_ColorRGBA()
    status = False
    # A TDF_Label is never None; stop once Father() has walked past the root (null label)
    while not status and not label.IsNull():
        try:
            status = colorTool.GetColor(label, XCAFDoc_ColorSurf, col)
        except:
            break
        label = label.Father()
    return (col.GetRGB().Red(), col.GetRGB().Green(), col.GetRGB().Blue(), col.Alpha(), status, col)

# Finds all child shapes and simplifies them recursively. Returns true if there were any subshapes.
# For now this assumes all shapes passed into this are translated as "SimpleShape". 
# "Assembly" entities should be skipped as we don't need to touch them, "Compound" entities should work with this as well, though the behaviour is untested. 
# Use the print_shape_type(shapeLabel, shapeTool) method to identify a shape.
def simplify_subshapes(shapeLabel, shapeTool, colorTool, set_colours=None):
    labels = TDF_LabelSequence()
    shapeTool.GetSubShapes_(shapeLabel, labels)
    #print_shape_type(shapeLabel, shapeTool)
    #print(f"{shapeTool.GetShape_(shapeLabel).ShapeType()}")
    cols = {}

    for i in range(1, labels.Length()+1):
        label = labels.Value(i)
        currShape = shapeTool.GetShape_(label)
        print(f"\t{currShape.ShapeType()}")
        if currShape.ShapeType() == TopAbs_COMPOUND:
            # This code path should never be taken as far as I understand
            simplify_subshapes(label, shapeTool, colorTool, set_colours)
        else:
            ''' See the comment at the bottom of the main loop for an explanation of the function of this block
            col = find_color(label, colorTool)
            #print(f"{name} RGBA: {col[0]:.5f} {col[1]:.5f} {col[2]:.5f} {col[3]:.5f} defined={col[4]}")
            cols[label.Tag()] = col

            if set_colours != None:
                colorTool.SetColor(label, set_colours[label.Tag()][5], XCAFDoc_ColorSurf)'''

            # Doing both of these things seems to result in colours being reset but the geometry doesn't get replaced
            nshape = simplify(currShape)
            shapeTool.SetShape(label, nshape) # This doesn't work

    return labels.Length() > 0, cols

# Set up XCaf Document
app = XCAFApp_Application.GetApplication_()
fmt = TCollection_ExtendedString('MDTV-XCAF')
doc = TDocStd_Document(fmt)
app.InitDocument(doc)

shapeTool = XCAFDoc_DocumentTool.ShapeTool_(doc.Main())
colorTool = XCAFDoc_DocumentTool.ColorTool_(doc.Main())

# Import the step file
reader = STEPCAFControl_Reader()
reader.SetNameMode(True)
reader.SetColorMode(True)
Interface_Static.SetIVal_("read.stepcaf.subshapes.name", 1) # Tells the importer to import subshape names

reader.ReadFile("testcolours.step")
reader.Transfer(doc)
labels = TDF_LabelSequence()
shapeTool.GetShapes(labels)

# Simplify each shape that was imported
for i in range(1, labels.Length()+1):
    label = labels.Value(i)
    shape = shapeTool.GetShape_(label)

    # Assemblies are just made of other shapes, so we'll skip this and simplify them individually...
    if shapeTool.IsAssembly_(label):
        continue

    # This function call here is meant to be the fix for the bug described.
    # The idea was to check if the TopoDS_Shape we're looking at is a COMPOUND and if so we would simplify and call SetShape() 
    # on each of the sub-shapes instead in an attempt to preserve the colours stored in the sub-shape's labels.
    #status, loadedCols = simplify_subshapes(label, shapeTool, colorTool)
    #if status:
        #continue

    shape = simplify(shape)
    shapeTool.SetShape(label, shape)

    # The code gets a bit messy here because this was another attempt at fixing the problem by building a dictionary of colours 
    # before the shapes were simplified and then resetting the colours of each subshape after simplification. 
    # This didn't work either.
    # But the idea was to call this function once to generate the dictionary, then simplify, then call it again passing in the dictionary so it could be re-applied.
    #if status:
    #    simplify_subshapes(label, shapeTool, colorTool, loadedCols)

shapeTool.UpdateAssemblies()

# Re-export
writer = STEPCAFControl_Writer()
Interface_Static.SetIVal_("write.step.assembly", 2) 
Interface_Static.SetIVal_("write.stepcaf.subshapes.name", 1)
writer.Transfer(doc, STEPControl_AsIs)
writer.Write("testcolours-simplified.step")
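The parent walk that find_color relies on (use a label's own colour, otherwise inherit the nearest ancestor's) can be modelled in pure Python. In this sketch, dictionaries are hypothetical stand-ins for labels and for XCAFDoc_ColorTool; the loop ends explicitly once it has passed the root, which on the OCCT side corresponds to Father() returning a null label.

```python
from typing import Dict, Optional, Tuple

Colour = Tuple[float, float, float, float]  # RGBA in [0, 1]


def find_colour(label: str,
                parent: Dict[str, Optional[str]],
                colours: Dict[str, Colour]) -> Tuple[Optional[Colour], bool]:
    """Return (colour, found) for the nearest label, starting at `label`, that has one."""
    current: Optional[str] = label
    while current is not None:        # OCCT side: `not label.IsNull()`
        if current in colours:        # OCCT side: colorTool.GetColor(label, XCAFDoc_ColorSurf, col)
            return colours[current], True
        current = parent[current]     # OCCT side: label.Father(), a null label above the root
    return None, False


# Hypothetical tree: Body1 has no colour of its own and inherits Component3's.
parent = {"Root": None, "Component3": "Root", "Body1": "Component3"}
colours = {"Component3": (0.8, 0.1, 0.1, 1.0)}

print(find_colour("Body1", parent, colours))  # ((0.8, 0.1, 0.1, 1.0), True)
print(find_colour("Root", parent, colours))   # (None, False)
```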

There's a lot of stuff here for a minimal reproducible example, but the general flow of the program is that we import the STEP file:

reader.ReadFile("testcolours.step")
reader.Transfer(doc)

Then we iterate through each label in the file (essentially every node in the tree):

labels = TDF_LabelSequence()
shapeTool.GetShapes(labels)

# Simplify each shape that was imported
for i in range(1, labels.Length()+1):
    label = labels.Value(i)
    shape = shapeTool.GetShape_(label)

We skip any labels marked as assemblies, since they contain children and we only want to simplify individual bodies. We then call simplify(shape), which performs the simplification and returns a new shape, and finally call shapeTool.SetShape() to bind the new shape to the old label.
What doesn't work here is that, as explained above, Component3 and Component4 don't get marked as assemblies. They are treated as SimpleShapes instead, and when each of them is simplified as one shape, the colours are lost.
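One plausible explanation (an assumption, not verified against the OCCT sources): for these SimpleShape compounds the colours live on sub-shape labels, and each sub-shape label stores the original solid. After SetShape() replaces the parent with a newly built compound, those stored solids are no longer part of the exported shape, so the writer has nothing to attach their colours to. The pure-Python model below mimics this identity-based matching; `Solid`, `exported_colours` and `fake_simplify` are hypothetical names, not OCCT API.

```python
class Solid:
    """Hypothetical stand-in for a TopoDS solid; matched by identity, like a TShape."""

    def __init__(self, name: str):
        self.name = name


def exported_colours(parent_solids, subshape_colours):
    """Model of the writer: keep colours only for solids found in the parent shape."""
    return {solid.name: colour
            for solid, colour in subshape_colours
            if any(solid is p for p in parent_solids)}


def fake_simplify(solids):
    """Stand-in for defeaturing: every solid comes back as a brand-new object."""
    return [Solid(s.name) for s in solids]


body1, body2 = Solid("Body1"), Solid("Body2")
component3 = [body1, body2]                           # Component3 imported as one compound
subshape_colours = [(body1, "red"), (body2, "blue")]  # colours on its sub-shape labels

before = exported_colours(component3, subshape_colours)
component3 = fake_simplify(component3)                # like shapeTool.SetShape(label, new_shape)
after = exported_colours(component3, subshape_colours)

print(before)  # {'Body1': 'red', 'Body2': 'blue'}
print(after)   # {}
```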

One solution I attempted was to call a method simplify_subshapes() which iterates through each of the sub-shapes and does the same thing as the main loop: simplifying them and then calling SetShape(). This turned out even worse: those bodies were not simplified at all, yet they still lost their colours.
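A possible reading of that result, again an assumption rather than something confirmed in the OCCT sources: SetShape() on a sub-shape label only replaces the shape stored on that label, while the parent label's compound (which is what gets exported) is not rebuilt from its sub-shapes; UpdateAssemblies() appears to rebuild only assembly compounds. That would produce exactly the observed combination of unchanged geometry plus lost colours. The model below uses a hypothetical dictionary of label names to stored shapes.

```python
class Solid:
    """Hypothetical stand-in for a TopoDS solid."""

    def __init__(self, name: str, n_faces: int):
        self.name, self.n_faces = name, n_faces


# Hypothetical model of the label tree: each label stores one shape.
body1 = Solid("Body1", n_faces=12)
stored = {
    "Component3": [body1],        # the parent label stores the whole compound
    "Component3/Body1": body1,    # the sub-shape label, which carries the colour
}

# Like shapeTool.SetShape() on the sub-shape label only: the parent is not rebuilt.
stored["Component3/Body1"] = Solid("Body1", n_faces=8)

exported = stored["Component3"]   # the writer exports the parent's compound
sub_still_in_parent = any(s is stored["Component3/Body1"] for s in exported)

print(exported[0].n_faces)        # 12 -> geometry was not simplified
print(sub_still_in_parent)        # False -> the colour has nothing to attach to
```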

I also attempted to use the simplify_subshapes() method to build a dictionary of all the colours of the sub-shapes, then simplify the COMPOUND shape, and then call the same method again, this time to re-apply the colours to the sub-shapes using the dictionary (the code for this is commented out above, with an explanation of what it did):

            col = find_color(label, colorTool)
            #print(f"{name} RGBA: {col[0]:.5f} {col[1]:.5f} {col[2]:.5f} {col[3]:.5f} defined={col[4]}")
            cols[label.Tag()] = col

            if set_colours != None:
                colorTool.SetColor(label, set_colours[label.Tag()][5], XCAFDoc_ColorSurf)
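An untested alternative to keying the dictionary by label.Tag() would be to key colours by shape and carry them across the defeaturing via its history: call SetToFillHistory(True) instead of False, then ask the defeaturer for Modified() / IsDeleted() on each coloured solid or face, and re-attach the colours to the resulting shapes (for example via XCAFDoc_ShapeTool's AddSubShape() and XCAFDoc_ColorTool's SetColor(); the pyOCCT binding details may differ). The pure-Python function below models only that bookkeeping; dictionaries and sets stand in for the OCCT history, and the shape names are hypothetical.

```python
from typing import Dict, Hashable, List, Set


def transfer_colours(old_colours: Dict[Hashable, str],
                     modified: Dict[Hashable, List[Hashable]],
                     deleted: Set[Hashable]) -> Dict[Hashable, str]:
    """Carry colours from pre-defeaturing shapes to their post-defeaturing images.

    `modified` stands in for the defeaturer's Modified(shape) history and
    `deleted` for IsDeleted(shape); shapes in neither are assumed unchanged.
    """
    new_colours: Dict[Hashable, str] = {}
    for old, colour in old_colours.items():
        if old in deleted:
            continue                      # nothing left to colour
        for new in modified.get(old, [old]):
            new_colours[new] = colour     # every image inherits the old colour
    return new_colours


# Hypothetical history: Body1 was simplified into a new solid, Body2 was
# untouched, and a coloured face Face7 was removed by the defeaturing.
old = {"Body1": "red", "Body2": "blue", "Face7": "green"}
history = {"Body1": ["Body1*"]}
gone = {"Face7"}

print(transfer_colours(old, history, gone))  # {'Body1*': 'red', 'Body2': 'blue'}
```

The point of keying by shape is that the history survives the parent compound being rebuilt, whereas label tags keep pointing at sub-shape labels whose stored shapes are stale.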

As far as I can see, the issue could be resolved either by getting Open Cascade to import Component3 and Component4 as assemblies, OR by finding a way to make SetShape() work as intended on sub-shapes.

Here's a link to the test file: testcolours.step

This is a cross-post of the question I posted on StackOverflow, because I don't get the impression many OCCT devs hang around there.
