Why would anyone do this?

I have been playing around with WebGPU recently, writing a toy engine in JS and running it on Chrome. The debugging situation is great now: PIX just works, and I can usually identify and fix issues quickly when they arise. Profiling, however, is another story. For instance, AMD’s Radeon Developer Panel will detect Chrome as a DX12 app (when using the proper flags), attach to the process and even let me start a capture. The problem is that this capture never completes. As someone who profiles GPU workloads for a living, I find it very frustrating not being able to do this in my own personal project. This post is the story of how I fixed that by shimming d3d12.dll and badgering Marco Castorina and Ray Dey with questions (thanks, guys!). You can find the code here and learn how to use it in my companion blog post GPU profiling for WebGPU workloads on Windows with Chrome.

Figuring out what’s missing

From the get-go, I had the intuition that the issue lay somewhere around presenting the image to the screen. I knew that GPU profilers usually rely on present events to determine which workload makes up a frame. Since the profiler would attach and let me start a capture that never completed, it seemed natural to assume it never caught the triggering event.

The first thing I did was ask the Chrome people whether there was any way to force Chrome to present through DX12. Sadly, there is no development flag or anything of the sort. I’m not entirely sure how Chrome ultimately presents the image, but I know it can be complicated, and that the WebGPU swapchain image is in the end just a texture shared with the main process, which does the final compositing*. (Don’t quote me on this; Chrome people, feel free to correct me.)

However, they pointed me towards an arcane DX12 feature that Dawn (Chrome’s WebGPU implementation) uses to let PIX know when a present would happen if it were to actually present.

void Texture::NotifySwapChainPresentToPIX() {
    // In PIX's D3D12-only mode, there is no way to determine frame boundaries
    // for WebGPU since Dawn does not manage DXGI swap chains. Without assistance,
    // PIX will wait forever for a present that never happens.
    // If we know we're dealing with a swapbuffer texture, inform PIX we've
    // "presented" the texture so it can determine frame boundaries and use its
    // contents for the UI.
    if (mSwapChainTexture) {
        ID3D12SharingContract* d3dSharingContract =
            ToBackend(GetDevice()->GetQueue())->GetSharingContract();
        if (d3dSharingContract != nullptr) {
            d3dSharingContract->Present(mResourceAllocation.GetD3D12Resource(), 0, 0);
        }
    }
}

What the hell is an ID3D12SharingContract, you ask? Honestly, I’m still not sure; feel free to read the documentation. Whatever it is, the following bit is not super reassuring.

Note that this interface is not supported when there are no diagnostic tools present, so your application mustn’t rely on it existing.

That sounds rather PIX-specific to me. How would I even let the runtime know that a diagnostic tool is present?

“I am the captain now” meme, with the captain replaced by a diagnostics tool

Let’s assume for now that this thing is working and that Chrome calls its Present function. How would AMD’s profiler catch it? The profiler works at the user-mode driver level, so the D3D12 runtime would have to pass that information all the way down. I already know that the runtime doesn’t pass down things like PIX markers: for RGP to see markers, we need to send them through AMD’s AGS library. This does not bode well.

I’m not a driver engineer, but out of curiosity I looked at what Microsoft exposes in its documentation, and it’s surprisingly decent? At least to an outside observer who’s never tried to actually implement a driver. It tells you which functions the driver should implement and which runtime callbacks the driver can call. Unsurprisingly, there is no mention of sharing contracts, and even if there were, using them would mean modifying the user-mode driver.

At this point I was a bit frustrated that the only thing missing seemed to be a present to catch my workload. That’s when I started to think: what if I present something myself? Anything really, it doesn’t matter. If it comes from the right process, the right device, the right something, maybe that would work. I would need some way to detect when ID3D12SharingContract::Present is called and issue an actual Present call instead. I was starting to look into Microsoft Detours (which I believe could achieve the same result) when Marco reminded me that GFXReconstruct exists, does exactly what I want and is open source!

GFXReconstruct is a library that can capture every single call made to a graphics API, record it, and replay it later for analysis. It is very powerful and even lets you capture multiple frames of an AAA game. It works by shimming the d3d12 and dxgi runtimes: it replaces the original d3d12.dll and dxgi.dll files with its own, exposes the same functions, does some extra work and then calls the actual functions in the original DLLs. It’s brilliant and exactly what I need. If I can wrap the ID3D12SharingContract object in my shim and get Chrome to load my DLL, I should be able to call the actual present function, which would trigger the profiler. Honestly, it was a long shot, but implementation setbacks aside, it worked flawlessly on the first try! Not only did it fix AMD’s profiler, but Nvidia’s as well.

Implementation

Fair warning, most of this is a blatant copy of what GFXReconstruct does. I haven’t invented anything, I’ve just straight up stolen GFXReconstruct’s code, simplified it and inserted the bits of code I needed in the right places.

Creating the DLL

Alright, so we’ll be building a DLL. We need to create a d3d12.def file in which we declare the functions we would like to export. GFXReconstruct exports all of the d3d12.dll functions, but I don’t need that many, so my def file only exposes the ones that Chrome requires.

LIBRARY d3d12
EXPORTS
D3D12CreateDevice=D3D12CreateDevice_webgpu_shim @101
D3D12GetDebugInterface=D3D12GetDebugInterface_webgpu_shim @102
D3D12SerializeRootSignature=D3D12SerializeRootSignature_webgpu_shim @115
D3D12CreateRootSignatureDeserializer=D3D12CreateRootSignatureDeserializer_webgpu_shim @107
D3D12SerializeVersionedRootSignature=D3D12SerializeVersionedRootSignature_webgpu_shim @116
D3D12CreateVersionedRootSignatureDeserializer=D3D12CreateVersionedRootSignatureDeserializer_webgpu_shim @108

We now need to define those functions in d3d12_webgpu_shim.cpp. For now they don’t have to do anything, but they need to exist, otherwise the DLL won’t build. Defining them is trivial since their prototypes are declared in d3d12.h: just return S_OK in the body of each function and the compiler will be happy. For instance, our D3D12CreateDevice_webgpu_shim stub looks like this.

HRESULT WINAPI D3D12CreateDevice_webgpu_shim(
    _In_opt_ IUnknown* pAdapter,
    D3D_FEATURE_LEVEL MinimumFeatureLevel,
    _In_ REFIID riid, // Expected: ID3D12Device
    _COM_Outptr_opt_ void** ppDevice)
{
    return S_OK;
}

Next, we need to define what happens when our DLL is loaded; this is done in a function called DllMain. We can already check whether Chrome loads our DLL by displaying a message box from this function and placing the DLL next to chrome.exe.

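BOOL WINAPI DllMain(
    HINSTANCE hinstDLL,
    DWORD fdwReason,
    LPVOID lpvReserved)
{
    // DllMain also runs on every thread attach/detach, so only show the box once,
    // when the DLL is first mapped into the process. Showing UI from DllMain is
    // not recommended in general (loader lock), but it's fine for a quick test.
    if (fdwReason == DLL_PROCESS_ATTACH)
    {
        MessageBoxA(NULL, "I'm the runtime now!", "Look at me", MB_OK);
    }

    return TRUE;
}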

Note that Chrome won’t load our DLL by default, probably as a security measure. It turns out that launching Chrome with the same flags required to take a PIX capture, --disable-gpu-sandbox --disable-gpu-watchdog, fixes that issue.

A picture of the message box over the webgpu samples page.

Success! Chrome does load my DLL, and I am the runtime now. Small problem though: the page crashes and nothing is displayed. That is not surprising, as our DLL doesn’t do anything apart from displaying a message box yet!

Implementing the functions

Fixing that is fairly easy. All we have to do is make our shim functions passthrough, meaning they should call the actual functions from the regular d3d12.dll. For this, we first need to load the regular DLL from the system directory. We can then get the function pointers we’re interested in from this DLL and store them for later. All of this happens in our DllMain function.

if (fdwReason == DLL_PROCESS_ATTACH)
{
    char systemDirectoryChar[MAX_PATH];
    UINT result = GetSystemDirectoryA(systemDirectoryChar, MAX_PATH);
    if (result == 0) {
        return FALSE;
    }

    std::string coreLibDirectory(systemDirectoryChar);
    d3d12Handle = LoadLibraryA((coreLibDirectory + "\\d3d12.dll").c_str());
    if (d3d12Handle == NULL)
    {
        return FALSE;
    }

    g_D3D12CreateDevice = reinterpret_cast<PFN_D3D12_CREATE_DEVICE>(GetProcAddress(d3d12Handle, "D3D12CreateDevice"));
    g_D3D12GetDebugInterface = reinterpret_cast<PFN_D3D12_GET_DEBUG_INTERFACE>(GetProcAddress(d3d12Handle, "D3D12GetDebugInterface"));
    g_D3D12SerializeRootSignature = reinterpret_cast<PFN_D3D12_SERIALIZE_ROOT_SIGNATURE>(GetProcAddress(d3d12Handle, "D3D12SerializeRootSignature"));
    g_D3D12CreateRootSignatureDeserializer = reinterpret_cast<PFN_D3D12_CREATE_ROOT_SIGNATURE_DESERIALIZER>(GetProcAddress(d3d12Handle, "D3D12CreateRootSignatureDeserializer"));
    g_D3D12SerializeVersionedRootSignature = reinterpret_cast<PFN_D3D12_SERIALIZE_VERSIONED_ROOT_SIGNATURE>(GetProcAddress(d3d12Handle, "D3D12SerializeVersionedRootSignature"));
    g_D3D12CreateVersionedRootSignatureDeserializer = reinterpret_cast<PFN_D3D12_CREATE_VERSIONED_ROOT_SIGNATURE_DESERIALIZER>(GetProcAddress(d3d12Handle, "D3D12CreateVersionedRootSignatureDeserializer"));
    g_D3D12EnableExperimentalFeatures = reinterpret_cast<PFN_D3D12_ENABLE_EXPERIMENTAL_FEATURES>(GetProcAddress(d3d12Handle, "D3D12EnableExperimentalFeatures"));
}

Making a function passthrough is then just a matter of using those pointers to call the functions from the original d3d12.dll.

HRESULT WINAPI D3D12GetDebugInterface_webgpu_shim(
    _In_ REFIID riid,
    _COM_Outptr_opt_ void** ppvDebug)
{
    return g_D3D12GetDebugInterface(riid, ppvDebug);
}

If all went well, with those changes we should now not only own the runtime but also not make the page crash! Before moving on to the interesting part, let’s not forget to free the library we loaded when the DLL detaches from the process in DllMain.

else if (fdwReason == DLL_PROCESS_DETACH)
{
    FreeLibrary(d3d12Handle);
}
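To see the passthrough mechanism in isolation, here is a minimal, portable C++ sketch that does not touch Windows at all. The hypothetical RealCreateWidget function stands in for an export of the original d3d12.dll, the g_CreateWidget pointer plays the role of g_D3D12CreateDevice, and LoadOriginals() stands in for the LoadLibraryA/GetProcAddress calls made on DLL_PROCESS_ATTACH. All of these names are made up for illustration.

```cpp
#include <cassert>

// Hypothetical stand-in for a function exported by the *original* DLL.
int RealCreateWidget(int size, int* outWidget)
{
    if (outWidget == nullptr) return -1; // plays the role of E_POINTER
    *outWidget = size * 2;
    return 0;                            // plays the role of S_OK
}

// Same signature as the original export, like PFN_D3D12_CREATE_DEVICE.
using PFN_CREATE_WIDGET = int (*)(int, int*);

// Filled in once at load time, like g_D3D12CreateDevice in DllMain.
PFN_CREATE_WIDGET g_CreateWidget = nullptr;

// Stand-in for LoadLibraryA + GetProcAddress on DLL_PROCESS_ATTACH.
bool LoadOriginals()
{
    g_CreateWidget = &RealCreateWidget;
    return g_CreateWidget != nullptr;
}

// The exported shim: same signature, simply forwards to the original.
int CreateWidget_shim(int size, int* outWidget)
{
    return g_CreateWidget(size, outWidget);
}
```

In the real shim the pointer comes from GetProcAddress and can be null if the export is missing, so checking it once at load time is a cheap safety net.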

Wrapping the device

Cool, cool… cool, cool, cool. We ain’t any closer to profiling the GPU yet. True! But now that device creation goes through our code, we can wrap the device and insert code into any of its functions! We’ll do this in D3D12CreateDevice_webgpu_shim. First, we call the actual D3D12CreateDevice function. Then we create our wrapper object from the actual ID3D12Device* (cast to an IUnknown*, more on that later), and finally we return the pointer to our wrapper to Chrome (through the ppDevice output argument) along with the return code provided by D3D12CreateDevice.

HRESULT WINAPI D3D12CreateDevice_webgpu_shim(
    _In_opt_ IUnknown* pAdapter,
    D3D_FEATURE_LEVEL MinimumFeatureLevel,
    _In_ REFIID riid, // Expected: ID3D12Device
    _COM_Outptr_opt_ void** ppDevice)
{
    HRESULT res = g_D3D12CreateDevice(pAdapter, MinimumFeatureLevel, riid, ppDevice);

    if (res == S_OK)
    {
        ID3D12Device_webgpu_shim* device = new ID3D12Device_webgpu_shim(reinterpret_cast<IUnknown*>(*ppDevice));
        *ppDevice = device;
    }

    return res;
}
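The create-then-wrap step can be sketched without any D3D12 types. In this hypothetical, portable example, IWidget stands in for ID3D12Device, RealCreateWidget for the original D3D12CreateDevice, and WidgetShim for ID3D12Device_webgpu_shim. Unlike the post’s wrappers, which mirror the COM vtables, the sketch uses ordinary C++ inheritance, and the wrapper simply owns the wrapped object instead of reference counting it.

```cpp
#include <cassert>

// Hypothetical interface standing in for ID3D12Device.
struct IWidget
{
    virtual ~IWidget() = default;
    virtual int Value() = 0;
};

// What the original DLL would hand back.
struct RealWidget : IWidget
{
    int Value() override { return 7; }
};

// Our wrapper: keeps the original object and forwards every call.
struct WidgetShim : IWidget
{
    explicit WidgetShim(IWidget* wrapped) : wrapped_(wrapped) {}
    WidgetShim(const WidgetShim&) = delete;
    WidgetShim& operator=(const WidgetShim&) = delete;
    ~WidgetShim() override { delete wrapped_; }
    int Value() override { return wrapped_->Value(); } // passthrough
    IWidget* wrapped_;
};

// Stand-in for the original creation function: 0 means success.
int RealCreateWidget(void** out)
{
    if (out == nullptr) return 1; // "supported", but nothing created
    *out = static_cast<IWidget*>(new RealWidget());
    return 0;
}

// Shim version: create the real object, then hand back a wrapper instead.
int CreateWidget_shim(void** out)
{
    int res = RealCreateWidget(out);
    if (res == 0 && out != nullptr)
    {
        *out = static_cast<IWidget*>(new WidgetShim(static_cast<IWidget*>(*out)));
    }
    return res;
}
```

The caller never notices the swap: it receives a pointer of the type it asked for, and every call still reaches the original object.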

Our wrapper object is also fairly simple. We can look at the ID3D12Device interface in d3d12.h and create a class that implements it. Note that this interface inherits from ID3D12Object, which itself inherits from IUnknown, so we will need to wrap those as well. This is also why we cast *ppDevice to IUnknown* earlier: this pointer is stored as a class attribute of IUnknown_webgpu_shim named object_. All wrapped classes store the pointer to the original object as an IUnknown* there.

We also need to add a constructor to our ID3D12Device_webgpu_shim class that takes an IUnknown* and passes it up the chain of parent constructors until it reaches IUnknown_webgpu_shim.

class ID3D12Device_webgpu_shim : public ID3D12Object_webgpu_shim
{
public:
    ID3D12Device_webgpu_shim(IUnknown* wrapped_object) : ID3D12Object_webgpu_shim(wrapped_object) {};
    // ...
};

The wrapped functions’ implementations are now trivial: all we have to do is cast object_ to the correct pointer type and call the original function on it. Here is one example, but we need to do this for every function defined in the ID3D12Device interface.

virtual HRESULT STDMETHODCALLTYPE CreateGraphicsPipelineState(
    _In_  const D3D12_GRAPHICS_PIPELINE_STATE_DESC* pDesc,
    REFIID riid,
    _COM_Outptr_  void** ppPipelineState)
{
    return reinterpret_cast<ID3D12Device*>(object_)->CreateGraphicsPipelineState(pDesc, riid, ppPipelineState);
}

We need to repeat the exact same process for ID3D12Object_webgpu_shim and IUnknown_webgpu_shim. The only difference for the latter is that it also declares object_ as a protected class attribute. If all went well, our shim is still working properly. We are not doing anything interesting yet, but we are 100% passthrough.
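Here is a portable sketch of how the wrapped pointer travels up the constructor chain and gets cast back to the right interface at each level. The IMock* interfaces are hypothetical stand-ins for IUnknown, ID3D12Object and ID3D12Device; the sketch only illustrates the storage-and-cast pattern and deliberately leaves out how the wrapper is handed back to the caller. It uses static_cast for the downcast, which is the portable equivalent of the reinterpret_cast used with single-inheritance COM interfaces.

```cpp
#include <cassert>
#include <string>

// Hypothetical mini-hierarchy mirroring IUnknown -> ID3D12Object -> ID3D12Device.
struct IMockUnknown { virtual ~IMockUnknown() = default; virtual int Id() = 0; };
struct IMockObject : IMockUnknown { virtual std::string GetName() = 0; };
struct IMockDevice : IMockObject { virtual int CreateThing(int x) = 0; };

struct RealDevice : IMockDevice
{
    int Id() override { return 1; }
    std::string GetName() override { return "real device"; }
    int CreateThing(int x) override { return x + 100; }
};

// Root of the wrapper chain: the only place the wrapped pointer is stored.
class Unknown_shim
{
public:
    explicit Unknown_shim(IMockUnknown* wrapped) : object_(wrapped) {}
    int Id() { return object_->Id(); }
protected:
    IMockUnknown* object_;
};

// Middle level: forwards its constructor argument and casts object_ back.
class Object_shim : public Unknown_shim
{
public:
    explicit Object_shim(IMockUnknown* wrapped) : Unknown_shim(wrapped) {}
    std::string GetName() { return static_cast<IMockObject*>(object_)->GetName(); }
};

// Leaf level: same idea, casting to the most specific interface.
class Device_shim : public Object_shim
{
public:
    explicit Device_shim(IMockUnknown* wrapped) : Object_shim(wrapped) {}
    int CreateThing(int x) { return static_cast<IMockDevice*>(object_)->CreateThing(x); }
};
```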

Memory cleanup

The astute reader will notice however that we have allocated memory for our wrapper ID3D12Device_webgpu_shim but have never freed it. We are also not keeping any reference to it so we can’t really call delete when we’re done. What we know though, is that to free a device in D3D12, we call device->Release(). And as it turns out, we did implement Release in our IUnknown_webgpu_shim wrapper.

If we knew that there were no references to the wrapper left, we could probably delete it from there, right? I am far from an expert on this, but apparently this is a classic strategy! All we have to do is add a refCount_ attribute to our wrapper, initialize it to one on creation, increment it in AddRef, decrement it in Release, and call delete this when it reaches zero.

Small problem, though: if we want the destructors of the entire chain to be called, we need to declare IUnknown_webgpu_shim’s destructor as virtual, and for some reason that seems to break the shim. I’m not 100% sure why, but it looks like it messes up the vtable; something something, it’s not described in the IUnknown interface? (C++ experts, let me know!) My solution is to override the Release function in the child class. It explicitly calls IUnknown_webgpu_shim::Release() first, and then calls delete this if refCount_ is zero. Since delete this is then called on the most derived type, every destructor in the chain runs. Here it doesn’t matter, as our ID3D12Device_webgpu_shim class doesn’t need to do anything in its destructor, but it will become important later.

virtual ULONG STDMETHODCALLTYPE Release() override
{
    ULONG res = IUnknown_webgpu_shim::Release();

    if (refCount_ == 0)
    {
        delete this;
    }

    return res;
}
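The reference-counting scheme can be sketched portably as follows. MockComObject is a hypothetical stand-in for the wrapped COM object, and the sketch assumes that the base Release both forwards to the wrapped object and decrements refCount_ (the body of IUnknown_webgpu_shim::Release is not shown in this post). The base destructor is deliberately non-virtual, and the leaf class overrides Release so that delete this runs on the most derived type.

```cpp
#include <cassert>

// Hypothetical stand-in for the wrapped COM object's own reference count.
struct MockComObject
{
    unsigned long refs = 1;
    unsigned long AddRef() { return ++refs; }
    unsigned long Release() { return --refs; } // a real object would free itself at 0
};

// Counts wrapper destructor calls so the cleanup can be observed.
int g_destroyed = 0;

// Root wrapper: tracks its own count and forwards to the wrapped object.
class Unknown_shim
{
public:
    explicit Unknown_shim(MockComObject* wrapped) : object_(wrapped) {}
    ~Unknown_shim() { ++g_destroyed; } // intentionally NOT virtual

    virtual unsigned long AddRef() { object_->AddRef(); return ++refCount_; }
    virtual unsigned long Release() { object_->Release(); return --refCount_; }

protected:
    MockComObject* object_;
    unsigned long refCount_ = 1; // the creator holds the first reference
};

// Leaf wrapper: its Release, not the base one, decides when to delete.
class Device_shim final : public Unknown_shim
{
public:
    explicit Device_shim(MockComObject* wrapped) : Unknown_shim(wrapped) {}
    ~Device_shim() { ++g_destroyed; }

    unsigned long Release() override
    {
        unsigned long res = Unknown_shim::Release();
        if (refCount_ == 0)
        {
            delete this; // static type is Device_shim, so both destructors run
        }
        return res;
    }
};
```

Marking the leaf class final guarantees that this really is the most derived type when Release deletes it, which is what lets the non-virtual base destructor still run after the derived one.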

Catching the sharing contract

OK, so far we have implemented a wrapper for the ID3D12Device interface, but we’re still not catching the creation of the sharing contract. Looking further at Chrome’s code, we can see that the sharing contract is obtained by calling the As method of a ComPtr<ID3D12CommandQueue>.

// If PIX is not attached, the QueryInterface fails. Hence, no need to check the return
// value.
mCommandQueue.As(&mD3d12SharingContract);

That’s excellent news! The As function internally calls QueryInterface, which is defined in IUnknown, and ID3D12CommandQueue inherits from the IUnknown interface. We can thus simply create a command queue wrapper that overrides the QueryInterface function. In there, we check whether riid == IID_ID3D12SharingContract and, if so, instantiate our sharing contract wrapper the same way we did for the device.

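virtual HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void** object)
{
    HRESULT res = object_->QueryInterface(riid, object);

    if (riid == IID_ID3D12SharingContract)
    {
        // Without a diagnostic tool attached, the real QueryInterface fails and
        // sets *object to null; our wrapper then has no underlying object.
        ID3D12SharingContract_webgpu_shim* sharingContract = new ID3D12SharingContract_webgpu_shim(reinterpret_cast<IUnknown*>(*object));
        *object = sharingContract;
        // We now hand out a valid object, so report success.
        return S_OK;
    }
    return res;
}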
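The interception logic can be exercised without D3D12 using hypothetical mock types: MockIID replaces REFIID, RealQueryInterface plays the command queue’s QueryInterface with no diagnostic tool attached (so it always fails), and SharingContract_shim accepts a null underlying object. The sketch returns success whenever it substitutes its wrapper, since the caller then receives a usable object; other IIDs are forwarded untouched.

```cpp
#include <cassert>
#include <cstring>

// Hypothetical 128-bit interface id, standing in for REFIID / GUID.
struct MockIID { unsigned char bytes[16]; };
bool operator==(const MockIID& a, const MockIID& b)
{
    return std::memcmp(a.bytes, b.bytes, sizeof(a.bytes)) == 0;
}

const MockIID IID_MockSharingContract = {{0x01}};
const MockIID IID_MockOther = {{0x02}};

constexpr long MOCK_S_OK = 0;
constexpr long MOCK_E_NOINTERFACE = 1;

// Wrapper that tolerates a missing underlying object.
struct SharingContract_shim
{
    explicit SharingContract_shim(void* wrapped) : object_(wrapped) {}
    void* object_;
};

// Stand-in for the real queue's QueryInterface with no tool attached:
// following the COM rule, it nulls the out pointer on failure.
long RealQueryInterface(const MockIID&, void** out)
{
    *out = nullptr;
    return MOCK_E_NOINTERFACE;
}

// The queue wrapper's QueryInterface: forward, then intercept one IID.
long QueryInterface_shim(const MockIID& riid, void** out)
{
    long res = RealQueryInterface(riid, out);
    if (riid == IID_MockSharingContract)
    {
        *out = new SharingContract_shim(*out); // *out may be null here
        return MOCK_S_OK;
    }
    return res;
}
```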

Please note that our sharing contract wrapper needs to be slightly different from the other wrappers because the underlying object might not exist.

This interface is not supported when there are no diagnostic tools present, so your application mustn’t rely on it existing.

In an attempt to keep the DLL compatible with PIX, I decided to still call the actual functions when QueryInterface succeeded, but honestly I might as well have left those functions empty.

virtual void STDMETHODCALLTYPE SharedFenceSignal(
    _In_  ID3D12Fence* pFence,
    UINT64 FenceValue)
{
    if (object_ != nullptr)
    {
        return reinterpret_cast<ID3D12SharingContract*>(object_)->SharedFenceSignal(pFence, FenceValue);
    }
}

Presenting for real

Phew, we finally managed to catch the creation of the sharing contract, wrap it and make it passthrough. At that point I added a message box in the Present function and checked that I got one every frame. We’ve never been this close to testing our theory. Will calling the actual Present function trigger the profilers? What are they listening to? Any present in the same process? Presenting requires a swapchain, so which swapchain are we presenting? Do profilers listen to any swapchain attached to the command queue that ran the work? Any command queue? Creating a swapchain usually requires a window; should I create a new window? So many questions!

Let’s take a deep breath and try the easiest thing. My first attempt was to create a swapchain the classic way, using CreateSwapChainForHwnd with a window and everything. And yes, I did spend some time figuring out that the parameter called pDevice expects an ID3D12CommandQueue, obviously. It was a bit of a mess: I needed to store the HINSTANCE of the DLL, define my window class and so on. I actually finished the prototype with this implementation, but I later switched to CreateSwapChainForComposition, which makes everything simpler, so we’ll pretend I did that on the first go. This function doesn’t require a window and works perfectly fine for my use case :).

Creating the swapchain itself is pretty straightforward. You will need an IDXGIFactory4, which you can get by calling CreateDXGIFactory. This is a DXGI function though, so you’ll need to load dxgi.dll in DllMain, get a pointer to that function and store it somewhere, much like for D3D12CreateDevice and others.

    dxgiHandle = LoadLibraryA((coreLibDirectory + "\\dxgi.dll").c_str());
    if (dxgiHandle == NULL)
    {
        return FALSE;
    }
    g_DXGICreateFactory = reinterpret_cast<PFN_CREATE_DXGI_FACTORY>(GetProcAddress(dxgiHandle, "CreateDXGIFactory"));

On top of this factory you’ll need a swapchain descriptor, I’ve made it as simple as possible. It’s a 1 by 1 pixel R8G8B8A8_UNORM render target with 2 buffers and SampleDesc.Count set to one. I have also set the scaling to DXGI_SCALING_STRETCH and the swap effect to DXGI_SWAP_EFFECT_FLIP_SEQUENTIAL as mandated by the docs.

Now, to create this swapchain we still need to pass an ID3D12CommandQueue as the first parameter. The simplest solution for me was to create a swapchain every time a new command queue is created. This is super easy: let’s just move the code described above into ID3D12CommandQueue_webgpu_shim’s constructor and store a pointer to the swapchain as a class member. Note that we’ll need to call swapChain_->Release(); in the destructor to free it. Good thing we made sure earlier that our entire chain of destructors actually gets called!

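// Requires <dxgi1_4.h> and <wrl/client.h> (for Microsoft::WRL::ComPtr).
ID3D12CommandQueue_webgpu_shim(IUnknown* wrapped_object) : ID3D12Pageable_webgpu_shim(wrapped_object), swapChain_(nullptr)
{
    ComPtr<IDXGIFactory4> factory;
    HRESULT hr = g_DXGICreateFactory(IID_PPV_ARGS(&factory));
    if (hr != S_OK)
    {
        return;
    }

    // Smallest possible swapchain: we only need it to emit present events.
    DXGI_SWAP_CHAIN_DESC1 swapChainDesc = {};
    swapChainDesc.Width = 1;
    swapChainDesc.Height = 1;
    swapChainDesc.Format = DXGI_FORMAT_R8G8B8A8_UNORM;
    swapChainDesc.BufferUsage = DXGI_USAGE_RENDER_TARGET_OUTPUT;
    swapChainDesc.BufferCount = 2;
    swapChainDesc.Scaling = DXGI_SCALING_STRETCH; // required for composition swapchains
    swapChainDesc.SwapEffect = DXGI_SWAP_EFFECT_FLIP_SEQUENTIAL;
    swapChainDesc.SampleDesc.Count = 1;

    // For D3D12, the "device" parameter is the command queue we are wrapping.
    // If creation fails, swapChain_ is left null and Present does nothing.
    factory->CreateSwapChainForComposition(object_, &swapChainDesc, nullptr, &swapChain_);
};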
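The ownership rule here (create in the constructor, release in the destructor) can be sketched portably. MockSwapChain is a hypothetical stand-in for IDXGISwapChain1 with just a reference count and a present counter, and CommandQueue_shim allocates one directly instead of calling CreateSwapChainForComposition. Copying is disabled so the swapchain cannot be released twice.

```cpp
#include <cassert>

// Hypothetical stand-in for IDXGISwapChain1: a refcount and a present counter.
struct MockSwapChain
{
    static int liveCount; // how many swapchains currently exist
    unsigned long refs = 1;
    int presents = 0;

    MockSwapChain() { ++liveCount; }
    unsigned long Release()
    {
        unsigned long r = --refs;
        if (r == 0) { --liveCount; delete this; }
        return r;
    }
    long Present(unsigned, unsigned) { ++presents; return 0; }
};
int MockSwapChain::liveCount = 0;

// Hypothetical queue wrapper: owns one swapchain for its whole lifetime.
class CommandQueue_shim
{
public:
    CommandQueue_shim() : swapChain_(new MockSwapChain()) {}
    CommandQueue_shim(const CommandQueue_shim&) = delete;
    CommandQueue_shim& operator=(const CommandQueue_shim&) = delete;
    ~CommandQueue_shim()
    {
        if (swapChain_ != nullptr) swapChain_->Release(); // balance the creation
    }
    MockSwapChain* SwapChain() const { return swapChain_; }

private:
    MockSwapChain* swapChain_;
};
```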

I actually wonder if we could get away with creating our own device, command queue and swapchain once and for all in DllMain and always present from that one swapchain. I might come back and update this post if I get around to it. For now, this is left as an exercise to the reader :p.

Alright, we’re very, very close to the end here. We have a swapchain per command queue from which we can present, but what we actually want is to present from the ID3D12SharingContract_webgpu_shim class. Not a problem: the ID3D12SharingContract_webgpu_shim object is created in the QueryInterface function of ID3D12CommandQueue_webgpu_shim, so all we have to do is pass the swapChain_ class member on to ID3D12SharingContract_webgpu_shim’s constructor and store it there as a class attribute as well.

class ID3D12SharingContract_webgpu_shim: public IUnknown_webgpu_shim
{
public:
    ID3D12SharingContract_webgpu_shim(IUnknown* wrapped_object, IDXGISwapChain1* swapChain) : IUnknown_webgpu_shim(wrapped_object), swapChain_(swapChain) {};
// ...
private:
    // Borrowed from ID3D12CommandQueue_webgpu_shim, which owns (and releases) it.
    IDXGISwapChain1* swapChain_;
};

The last thing we need to do is to implement the Present function in ID3D12SharingContract_webgpu_shim, which is as simple as calling swapChain_->Present(0, 0);.

virtual void STDMETHODCALLTYPE Present(
    _In_  ID3D12Resource * pResource,
    UINT Subresource,
    _In_  HWND window)
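{
    // The arguments describe what Chrome would have presented; we ignore them and
    // issue a real present on our dummy swapchain so profilers see a frame boundary.
    if (swapChain_ != nullptr)
    {
        swapChain_->Present(0, 0);
    }
}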
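To tie the last pieces together, here is a portable sketch of the sharing contract wrapper’s behavior with hypothetical mock types: it borrows the swapchain pointer from the command queue wrapper, forwards calls such as SharedFenceSignal only when a real sharing contract exists, and turns Present into a real present on the dummy swapchain. Because the pointer is borrowed, this assumes the command queue wrapper outlives the sharing contract wrapper.

```cpp
#include <cassert>

// Hypothetical stand-ins; none of these are real D3D12/DXGI types.
struct MockSwapChain
{
    int presents = 0;
    long Present(unsigned syncInterval, unsigned flags)
    {
        (void)syncInterval;
        (void)flags;
        ++presents;
        return 0;
    }
};

struct MockRealSharingContract
{
    int fenceSignals = 0;
    void SharedFenceSignal(unsigned long long) { ++fenceSignals; }
};

class SharingContract_shim
{
public:
    // swapChain is borrowed from the queue wrapper, which keeps ownership.
    SharingContract_shim(MockRealSharingContract* wrapped, MockSwapChain* swapChain)
        : object_(wrapped), swapChain_(swapChain) {}

    // Passthrough only when a diagnostic tool actually provided the interface.
    void SharedFenceSignal(unsigned long long value)
    {
        if (object_ != nullptr) object_->SharedFenceSignal(value);
    }

    // Ignore what the caller wants to "present"; just emit a real present.
    void Present()
    {
        if (swapChain_ != nullptr) swapChain_->Present(0, 0);
    }

private:
    MockRealSharingContract* object_;
    MockSwapChain* swapChain_;
};
```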

Closing thoughts

And I think that’s it, folks. Once I had implemented this, it just worked. I thought this would be the hard part, but nope, it worked on the first try! All that’s left is to learn how to use it with my companion blog post GPU profiling for WebGPU workloads on Windows with Chrome, get some xp using a GPU profiler and optimize the hell out of every web-based renderer you cross paths with!

Thanks a lot for reading this far! Please leave a comment if you have questions or anything to say. I’m eager to hear your thoughts on this nice little hack 😁.